/*
 * 
 * This source code is part of 
 *   MARBLE (MoleculAR simulation package for BiomoLEcules)
 * 
 * Written by Mitsunori Ikeguchi
 * Copyright (c) 2012 Yokohama City University
 *  
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 */

#include <stdio.h>
#include <math.h>
#include <string.h>

#include "md_system.h"

/* Notice: grp->rg and grp->vg are computed on demand; call RATTLE_calc_rg (for rg) or RATTLE_calc_vg (for vg) before reading them. */

/* Put RATTLE into the "off" state; RATTLE_setup() switches it on. */
void RATTLE_init(RATTLE *rt)
{
  const int rattle_disabled = 0;

  rt->flag = rattle_disabled;
}

/* Time-zero hook for RATTLE.  Intentionally does nothing: the group
   centres (grp->rg / grp->vg) are recomputed by RATTLE_calc_rg and
   RATTLE_calc_vg at every place in this file that reads them. */
void RATTLE_init_time0(RATTLE *rt, ATOM_DATA *ad)
{
  (void) rt;
  (void) ad;
}

/* First velocity half-kick for RATTLE atoms: v += (dt/2) * f / w for every
   (local) atom flagged ATOM_RATTLE.  No constraint is applied here; rt and
   bd are unused but kept for a uniform integrator signature. */
void RATTLE_time_integration_v1(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad, double dt)
{
  int atom;
  double half_dt = dt * 0.5;

#ifdef MPI_SDMD
  for (atom = ad->node_atom_h; atom >= 0; atom = ad->node_atom_n[atom]) {
#else
  for (atom = 0; atom < ad->natom; atom++) {
#endif
    if (ad->ex[atom].flag & ATOM_RATTLE) {
      ad->v[atom].x += half_dt * ad->f[atom].x / ad->w[atom];
      ad->v[atom].y += half_dt * ad->f[atom].y / ad->w[atom];
      ad->v[atom].z += half_dt * ad->f[atom].z / ad->w[atom];
    }
  }
}

/* Second velocity half-kick for RATTLE atoms followed by the velocity
   constraint (RATTLE_b), which corrects the relative velocities along the
   RATTLE bonds. */
void RATTLE_time_integration_v2(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad, double dt)
{
  /* The unconstrained half-kick is exactly the same as the first one. */
  RATTLE_time_integration_v1(rt, bd, ad, dt);

  /* Then enforce the velocity constraints. */
  RATTLE_b(rt, bd, ad);
}

/* Drift step for RATTLE atoms followed by the position constraint.
   For each (local) atom flagged ATOM_RATTLE the current position is saved
   in ad->px and then advanced by dt * v; RATTLE_a then restores the bond
   lengths, using ad->px for the reference bond vectors. */
void RATTLE_time_integration_p(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad, double dt)
{
  int atom;

#ifdef MPI_SDMD
  for (atom = ad->node_atom_h; atom >= 0; atom = ad->node_atom_n[atom]) {
#else
  for (atom = 0; atom < ad->natom; atom++) {
#endif
    if (ad->ex[atom].flag & ATOM_RATTLE) {
      ad->px[atom] = ad->x[atom];          /* reference for RATTLE_a */
      ad->x[atom].x += dt * ad->v[atom].x;
      ad->x[atom].y += dt * ad->v[atom].y;
      ad->x[atom].z += dt * ad->v[atom].z;
    }
  }

  RATTLE_a(rt, bd, ad, dt);
}

/* Drift step for RATTLE groups with a scalar (isotropic) coupling.
   Each group's centre of mass is moved to rg*AA2 + vg*BB (AA2/BB are
   supplied by the caller; presumably the barostat propagation factors),
   and ad->px receives the atom positions shifted by that centre-of-mass
   move.  Motion relative to the centre of mass is then propagated as
   px + dt * (v - vg), and RATTLE_a enforces the bond lengths. */
void RATTLE_time_integration_p_NPT(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad, double dt,
                                   double AA2, double BB)
{
  int g, atom;
  RATTLE_GROUP *grp;
  VEC com_new;

  RATTLE_calc_rg(rt, ad);
  RATTLE_calc_vg(rt, ad);

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    grp = rt->grp + g;

    com_new.x = grp->rg.x * AA2 + grp->vg.x * BB;
    com_new.y = grp->rg.y * AA2 + grp->vg.y * BB;
    com_new.z = grp->rg.z * AA2 + grp->vg.z * BB;

    atom = grp->parent_atom;
    while (atom >= 0) {
      ad->px[atom].x = ad->x[atom].x - grp->rg.x + com_new.x;
      ad->px[atom].y = ad->x[atom].y - grp->rg.y + com_new.y;
      ad->px[atom].z = ad->x[atom].z - grp->rg.z + com_new.z;

      ad->x[atom].x = ad->px[atom].x + dt * (ad->v[atom].x - grp->vg.x);
      ad->x[atom].y = ad->px[atom].y + dt * (ad->v[atom].y - grp->vg.y);
      ad->x[atom].z = ad->px[atom].z + dt * (ad->v[atom].z - grp->vg.z);

      atom = ad->ex[atom].child_list;
    }
  }

  RATTLE_a(rt, bd, ad, dt);
}

/* Drift step for RATTLE groups with a full 3x3 coupling.
   The centre-of-mass position and velocity of each group are transformed
   with Vg_vec (VEC_MUL_MAT_*), propagated component-wise with AA2[k] and
   BB[k], and transformed back (MAT_MUL_VEC_*) to give the new centre of
   mass.  Atoms are shifted with their group into ad->px, their motion
   relative to the centre of mass is propagated as px + dt * (v - vg), and
   RATTLE_a enforces the bond lengths.
   NOTE(review): Vg_vec is presumably the eigenvector basis of the
   barostat velocity matrix -- confirm against the caller. */
void RATTLE_time_integration_p_NPT_full(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad, double dt,
                                        double Vg_vec[3][3], double AA2[3], double BB[3])
{
  int g, atom;
  RATTLE_GROUP *grp;
  VEC r_rot, v_rot, com_new;

  RATTLE_calc_rg(rt, ad);
  RATTLE_calc_vg(rt, ad);

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    grp = rt->grp + g;

    /* centre of mass position and velocity in the Vg_vec frame */
    r_rot.x = VEC_MUL_MAT_X(grp->rg, Vg_vec);
    r_rot.y = VEC_MUL_MAT_Y(grp->rg, Vg_vec);
    r_rot.z = VEC_MUL_MAT_Z(grp->rg, Vg_vec);
    v_rot.x = VEC_MUL_MAT_X(grp->vg, Vg_vec);
    v_rot.y = VEC_MUL_MAT_Y(grp->vg, Vg_vec);
    v_rot.z = VEC_MUL_MAT_Z(grp->vg, Vg_vec);

    /* per-component propagation in that frame */
    r_rot.x = r_rot.x * AA2[0] + v_rot.x * BB[0];
    r_rot.y = r_rot.y * AA2[1] + v_rot.y * BB[1];
    r_rot.z = r_rot.z * AA2[2] + v_rot.z * BB[2];

    /* back to the original frame */
    com_new.x = MAT_MUL_VEC_X(Vg_vec, r_rot);
    com_new.y = MAT_MUL_VEC_Y(Vg_vec, r_rot);
    com_new.z = MAT_MUL_VEC_Z(Vg_vec, r_rot);

    atom = grp->parent_atom;
    while (atom >= 0) {
      ad->px[atom].x = ad->x[atom].x - grp->rg.x + com_new.x;
      ad->px[atom].y = ad->x[atom].y - grp->rg.y + com_new.y;
      ad->px[atom].z = ad->x[atom].z - grp->rg.z + com_new.z;

      ad->x[atom].x = ad->px[atom].x + dt * (ad->v[atom].x - grp->vg.x);
      ad->x[atom].y = ad->px[atom].y + dt * (ad->v[atom].y - grp->vg.y);
      ad->x[atom].z = ad->px[atom].z + dt * (ad->v[atom].z - grp->vg.z);

      atom = ad->ex[atom].child_list;
    }
  }

  RATTLE_a(rt, bd, ad, dt);
}

/* Scale the velocities of every RATTLE group: the centre-of-mass velocity
   by scale_tr[iex] and each atom's velocity relative to it by
   scale_rot[iex], where iex is the extended system of the group's parent
   atom.  On return grp->vg holds the scaled centre-of-mass velocity. */
void RATTLE_scale_velocity(RATTLE *rt, ATOM_DATA *ad, double *scale_tr, double *scale_rot)
{
  int g, atom, iex;
  RATTLE_GROUP *grp;
  double s_tr, s_rot;
  VEC vg_new;

  RATTLE_calc_vg(rt, ad);

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    grp = rt->grp + g;
    iex = ATOM_FLAG_EX_SYSTEM(ad->ex[grp->parent_atom].flag);
    s_tr  = scale_tr[iex];
    s_rot = scale_rot[iex];

    vg_new.x = grp->vg.x * s_tr;
    vg_new.y = grp->vg.y * s_tr;
    vg_new.z = grp->vg.z * s_tr;

    /* grp->vg must still be the unscaled value inside this loop */
    for (atom = grp->parent_atom; atom >= 0; atom = ad->ex[atom].child_list) {
      ad->v[atom].x = vg_new.x + (ad->v[atom].x - grp->vg.x) * s_rot;
      ad->v[atom].y = vg_new.y + (ad->v[atom].y - grp->vg.y) * s_rot;
      ad->v[atom].z = vg_new.z + (ad->v[atom].z - grp->vg.z) * s_rot;
    }

    grp->vg = vg_new;
  }
}

/* Like RATTLE_scale_velocity, but the centre-of-mass velocity of each group
   is transformed by the 3x3 matrix scale_tr[iex] (VEC_MUL_MAT_*), while the
   velocities relative to the centre of mass are scaled by scale_rot[iex].
   On return grp->vg holds the transformed centre-of-mass velocity. */
void RATTLE_scale_velocity_full(RATTLE *rt, ATOM_DATA *ad,
                                double scale_tr[][3][3],
                                double scale_rot[])
{
  int g, atom, iex;
  RATTLE_GROUP *grp;
  double s_rot;
  VEC vg_new;

  RATTLE_calc_vg(rt, ad);

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    grp = rt->grp + g;
    iex = ATOM_FLAG_EX_SYSTEM(ad->ex[grp->parent_atom].flag);
    s_rot = scale_rot[iex];

    vg_new.x = VEC_MUL_MAT_X(grp->vg, scale_tr[iex]);
    vg_new.y = VEC_MUL_MAT_Y(grp->vg, scale_tr[iex]);
    vg_new.z = VEC_MUL_MAT_Z(grp->vg, scale_tr[iex]);

    /* grp->vg must still be the untransformed value inside this loop */
    for (atom = grp->parent_atom; atom >= 0; atom = ad->ex[atom].child_list) {
      ad->v[atom].x = vg_new.x + (ad->v[atom].x - grp->vg.x) * s_rot;
      ad->v[atom].y = vg_new.y + (ad->v[atom].y - grp->vg.y) * s_rot;
      ad->v[atom].z = vg_new.z + (ad->v[atom].z - grp->vg.z) * s_rot;
    }

    grp->vg = vg_new;
  }
}


/* Accumulate RATTLE kinetic-energy terms of the form w * Length2(v)
   (no factor 1/2 is applied here).
     *kene_t    += sum over groups of  W_g * |vg|^2   (centre-of-mass part)
     *kene_r    += sum over atoms  of  w * |v|^2  minus the group part
     kene_arr[] += per-extended-system sum of w * |v|^2 over RATTLE atoms */
void RATTLE_kene(RATTLE *rt, ATOM_DATA *ad,
                 double *kene_t, double *kene_r, double *kene_arr)
{
  double k_com = 0.0, k_atoms = 0.0, k_one;
  int g, atom;
  RATTLE_GROUP *grp;

  RATTLE_calc_vg(rt, ad);

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    grp = &rt->grp[g];
    k_com += grp->w * Length2(grp->vg.x, grp->vg.y, grp->vg.z);
  }

#ifdef MPI_SDMD
  for (atom = ad->node_atom_h; atom >= 0; atom = ad->node_atom_n[atom]) {
#else
  for (atom = 0; atom < ad->natom; atom++) {
#endif
    if (ad->ex[atom].flag & ATOM_RATTLE) {
      k_one = ad->w[atom] * Length2(ad->v[atom].x, ad->v[atom].y, ad->v[atom].z);
      kene_arr[ATOM_FLAG_EX_SYSTEM(ad->ex[atom].flag)] += k_one;
      k_atoms += k_one;
    }
  }

  *kene_t += k_com;
  *kene_r += k_atoms - k_com;
}

/* Accumulate RATTLE kinetic-energy terms (w * |v|^2, no factor 1/2) per
   extended system.
     kene_tr_arr[iex] += centre-of-mass tensor W_g * vg_a * vg_b; only the
                         upper triangle ([0][0],[0][1],[0][2],[1][1],[1][2],
                         [2][2]) is accumulated here.
     kene_rot_arr[iex] += atomic total minus centre-of-mass part.
     *kene_t / *kene_r += the same translational / remaining parts summed
                          over the first n_ex_system systems. */
void RATTLE_kene_full(RATTLE *rt, ATOM_DATA *ad,
		      double *kene_t, double *kene_r, int n_ex_system,
		      double kene_tr_arr[MAX_EX_SYSTEM][3][3],
		      double kene_rot_arr[MAX_EX_SYSTEM])
{
  double rt_kene_tr[MAX_EX_SYSTEM], rt_kene_all[MAX_EX_SYSTEM];
  double ktx, kty, ktz, kene;
  int i, iex;
  RATTLE_GROUP *grp;

  RATTLE_calc_vg(rt,ad);

  /* Clear every slot, not only the first n_ex_system: iex below comes from
     the atom flags, and accumulating into an uninitialized slot would read
     indeterminate stack memory. */
  for (i=0;i<MAX_EX_SYSTEM;i++) {
    rt_kene_tr[i] = rt_kene_all[i] = 0.0;
  }

  /* centre-of-mass (translational) part, per group */
  for (i=rt->node_grp_h; i>=0; i=rt->grp[i].node_grp_n) {
    grp = &rt->grp[i];
    iex = ATOM_FLAG_EX_SYSTEM(ad->ex[grp->parent_atom].flag);

    ktx = grp->w * grp->vg.x * grp->vg.x;
    kty = grp->w * grp->vg.y * grp->vg.y;
    ktz = grp->w * grp->vg.z * grp->vg.z;

    rt_kene_tr[iex] += ktx + kty + ktz;

    kene_tr_arr[iex][0][0] += ktx;
    kene_tr_arr[iex][0][1] += grp->w * grp->vg.x * grp->vg.y;
    kene_tr_arr[iex][0][2] += grp->w * grp->vg.x * grp->vg.z;
    kene_tr_arr[iex][1][1] += kty;
    kene_tr_arr[iex][1][2] += grp->w * grp->vg.y * grp->vg.z;
    kene_tr_arr[iex][2][2] += ktz;
  }

  /* total over RATTLE atoms, per extended system */
#ifdef MPI_SDMD
  for (i = ad->node_atom_h; i>=0; i=ad->node_atom_n[i]) {
#else
  for (i = 0; i<ad->natom; i++) {
#endif
    if (!(ad->ex[i].flag & ATOM_RATTLE)) continue;
    kene = ad->w[i] * Length2(ad->v[i].x,ad->v[i].y,ad->v[i].z);
    iex = ATOM_FLAG_EX_SYSTEM(ad->ex[i].flag);
    rt_kene_all[iex] += kene;
  }

  /* remaining (intra-group) part = total - centre-of-mass part */
  for (i=0;i<n_ex_system;i++) {
    kene_rot_arr[i] += rt_kene_all[i] - rt_kene_tr[i];

    *kene_r += rt_kene_all[i] - rt_kene_tr[i];
    *kene_t += rt_kene_tr[i];
  }
}

/* Store in grp->rg the mass-weighted centre of the positions ad->x of each
   RATTLE group (sum of w*x over the group divided by the group mass grp->w). */
void RATTLE_calc_rg(RATTLE *rt, ATOM_DATA *ad)
{
  int g, atom;

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    RATTLE_GROUP *grp = &rt->grp[g];
    double sx = 0.0, sy = 0.0, sz = 0.0;

    for (atom = grp->parent_atom; atom >= 0; atom = ad->ex[atom].child_list) {
      sx += ad->w[atom] * ad->x[atom].x;
      sy += ad->w[atom] * ad->x[atom].y;
      sz += ad->w[atom] * ad->x[atom].z;
    }

    grp->rg.x = sx / grp->w;
    grp->rg.y = sy / grp->w;
    grp->rg.z = sz / grp->w;
  }
}

/* Store in grp->vg the mass-weighted mean of the velocities ad->v of each
   RATTLE group (sum of w*v over the group divided by the group mass grp->w). */
void RATTLE_calc_vg(RATTLE *rt, ATOM_DATA *ad)
{
  int g, atom;

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    RATTLE_GROUP *grp = &rt->grp[g];
    double px = 0.0, py = 0.0, pz = 0.0;   /* group momentum */

    for (atom = grp->parent_atom; atom >= 0; atom = ad->ex[atom].child_list) {
      px += ad->w[atom] * ad->v[atom].x;
      py += ad->w[atom] * ad->v[atom].y;
      pz += ad->w[atom] * ad->v[atom].z;
    }

    grp->vg.x = px / grp->w;
    grp->vg.y = py / grp->w;
    grp->vg.z = pz / grp->w;
  }
}

/*  RATTLE_a (below): position stage, first part of velocity Verlet
    ad->x:   updated position   (t+dt)
    ad->px:  previous position  (t)
    ad->v:   updated velocity   (t+dt/2)
    dt:      time step
*/
/* Copy the current position of every (local) ATOM_RATTLE atom into ad->px,
   which RATTLE_a uses as the reference geometry. */
void RATTLE_backup_x(ATOM_DATA *ad)
{
  int atom;

#ifdef MPI_SDMD
  for (atom = ad->node_atom_h; atom >= 0; atom = ad->node_atom_n[atom]) {
#else
  for (atom = 0; atom < ad->natom; atom++) {
#endif
    if (ad->ex[atom].flag & ATOM_RATTLE)
      ad->px[atom] = ad->x[atom];
  }
}

/*
 * RATTLE_a: position stage of RATTLE, called after the unconstrained drift.
 *
 * On entry ad->x holds the drifted positions, ad->px the reference
 * positions saved before the drift, and ad->v the velocities used for the
 * drift.  Each pass over the RATTLE bonds compares |r_ab|^2 with r0^2.  If
 * the relative deviation exceeds 2*tolerance, both atoms are displaced
 * along the reference bond vector px_a - px_b with inverse-mass weights,
 * and the same displacement divided by dt is added to their velocities.
 * Passes repeat until one makes no correction or rt->max_loop passes have
 * been done.
 *
 * ad->ex[].id is scratch state.  It is 1 when the atom was corrected in the
 * previous pass (all local atoms start at 1), 2 when corrected in the
 * current pass, and 0 otherwise.  Bonds whose atoms both have id 0 are
 * skipped.
 *
 * NOTE(review): a failure (bond projection puab too small) only prints an
 * error and returns.  Exhausting rt->max_loop without convergence is not
 * reported at all.  The caller cannot detect either case.
 */
void RATTLE_a(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad, double dt)
{
  int loop, done;
  int i;          /* bond number */
  int ia, ib;     /* atom number */
  double tol2;
  double rabsq;
  double udx, udy, udz, uabsq, udiffsq;
  double pdx, pdy, pdz, puab;
  double rma, rmb, gab;
  double dx, dy, dz;
  
  tol2 = rt->tolerance*2.0;
       /*   |P**2 - R**2| < R**2 * tol2 
            (P+R)|P-R| <  2*R * R*tol  
            thus, tol2 = tol*2.0       */
  done = 0;
  loop = 0;
  /* At first, flags (id) are set. */
#ifdef MPI_SDMD
  for (ia = ad->node_atom_h; ia>=0; ia=ad->node_atom_n[ia]) {
#else  
  for (ia=0;ia<ad->natom;ia++) {
#endif    
    ad->ex[ia].id = 1;
  }
  while (!done && loop < rt->max_loop) {
    done = 1; loop++;
#ifdef MPI_SDMD
    for (i=bd->rhead; i>=0; i=bd->bonds[i].rnext) {
#else    
    for (i=0;i<bd->n_bond;i++) {
#endif      
#ifdef MPI_SDMD
      /* if (bd->bonds[i].flag & BOND_OTHER_NODE) continue; */
#endif
      if (bd->bonds[i].flag & RATTLE_FLAG) {
	ia = bd->bonds[i].atom1;
	ib = bd->bonds[i].atom2;
	if (ad->ex[ia].id >= 1 || ad->ex[ib].id >= 1) {
	  /* if either atom moves, enter here */
	  
	  /* target squared bond length r0^2 */
	  rabsq = bd->bond_type[bd->bonds[i].type].r0;
	  rabsq *= rabsq;
	  
	  /* current (drifted) bond vector and its squared length */
	  udx = ad->x[ia].x - ad->x[ib].x;
	  udy = ad->x[ia].y - ad->x[ib].y;
	  udz = ad->x[ia].z - ad->x[ib].z;
	  uabsq = udx*udx+udy*udy+udz*udz;
	  udiffsq = rabsq - uabsq;

	  /* debug 
	  printf("udiffsq %10f %10f %10f\n", udiffsq, rabsq, uabsq); */

	  if (fabs(udiffsq) > rabsq * tol2) {
	    /* if not enough length.. */

	    /* reference bond vector and its projection on the current one */
	    pdx = ad->px[ia].x - ad->px[ib].x;
	    pdy = ad->px[ia].y - ad->px[ib].y;
	    pdz = ad->px[ia].z - ad->px[ib].z;
	    puab = pdx*udx+pdy*udy+pdz*udz;

	    /* the correction divides by puab; a near-zero or negative
	       projection means the iteration cannot proceed */
	    if (puab < rabsq * rt->tolerance) {
	      printf("ERROR: RATTLE FAILURE!\n");
	      return;
	    }

	    /* Lagrange-multiplier step along the reference bond vector,
	       distributed with inverse masses */
	    rma = 1.0/ad->w[ia];
	    rmb = 1.0/ad->w[ib];
	    gab = udiffsq / ( 2.0*(rma+rmb) * puab);
	    dx = pdx * gab;
	    dy = pdy * gab;
	    dz = pdz * gab;
	    ad->x[ia].x += rma * dx;
	    ad->x[ia].y += rma * dy;
	    ad->x[ia].z += rma * dz;
	    ad->x[ib].x -= rmb * dx;
	    ad->x[ib].y -= rmb * dy;
	    ad->x[ib].z -= rmb * dz;

	    /* apply the matching velocity change (displacement / dt) */
	    dx /= dt;
	    dy /= dt;
	    dz /= dt;
	    
	    ad->v[ia].x += rma * dx;
	    ad->v[ia].y += rma * dy;
	    ad->v[ia].z += rma * dz;
	    ad->v[ib].x -= rmb * dx;
	    ad->v[ib].y -= rmb * dy;
	    ad->v[ib].z -= rmb * dz;

	    ad->ex[ia].id = ad->ex[ib].id = 2;
	    done = 0;
	  }
	}
      }
    }
    /* atoms corrected in this pass (id 2) are re-checked next pass (id 1);
       all others are marked 0 */
#ifdef MPI_SDMD
    for (ia = ad->node_atom_h; ia>=0; ia=ad->node_atom_n[ia]) {
#else  
    for (ia=0;ia<ad->natom;ia++) {
#endif      
      if (ad->ex[ia].id == 2) {
	/* if the atom is involved in RATTLE */
	ad->ex[ia].id = 1;
      } else {
	/* if the atom is not involved in RATTLE */
	ad->ex[ia].id = 0;
      }
    }
  }
  /* lprintf("A loop = %d\n", loop);  */
}

/* RATTLE_b: velocity stage, second part of velocity Verlet
   ad->x : position (t+dt)
   ad->v : velocity after the second half-kick, before the constraint correction
*/
/*
 * RATTLE_b: velocity stage of RATTLE.
 *
 * For every RATTLE bond a-b, let r_ab be the current bond vector (ad->x)
 * and v_ab the relative velocity.  The routine computes
 *     g = -(r_ab . v_ab) / ((1/w_a + 1/w_b) * r0^2)
 * and, if |g| > tolerance, adds g*r_ab/w_a to v_a and subtracts g*r_ab/w_b
 * from v_b.  This drives r_ab . v_ab toward zero.  Passes repeat until one
 * makes no correction or rt->max_loop passes have been done.  ad->ex[].id
 * is used as scratch, as in RATTLE_a: bonds whose atoms were both left
 * unchanged in the previous pass are skipped.
 *
 * NOTE(review): unlike RATTLE_a, the initial id reset below loops over all
 * ad->natom atoms even under MPI_SDMD.  Non-convergence within
 * rt->max_loop is not reported.
 */
void RATTLE_b(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad)
{
  int done, loop;
  int ia, ib, i;
  double rabsq;
  double vxab, vyab, vzab;
  double rxab, ryab, rzab;
  double rma, rmb, gab, rvab;
  double dx, dy, dz;
  
  done = 0;
  loop = 0;
  /* At first, flags (id) are set. */
  for (ia=0;ia<ad->natom;ia++) {
    ad->ex[ia].id = 1;
  }
  while (!done && loop < rt->max_loop) {
    done = 1; loop++;
#ifdef MPI_SDMD
    for (i=bd->rhead; i>=0; i=bd->bonds[i].rnext) {
#else    
    for (i=0;i<bd->n_bond;i++) {
#endif      
#ifdef MPI_SDMD
      /*if (bd->bonds[i].flag & BOND_OTHER_NODE) continue;*/
#endif
      if (bd->bonds[i].flag & RATTLE_FLAG) {
	ia = bd->bonds[i].atom1;
	ib = bd->bonds[i].atom2;
	
	if (ad->ex[ia].id >= 1 || ad->ex[ib].id >= 1) { 
	  /* if either atom moves, enter here */
	
	  /* target squared bond length r0^2 */
	  rabsq = bd->bond_type[bd->bonds[i].type].r0;
	  rabsq *= rabsq;

	  /* relative velocity and bond vector; rvab is their dot product */
	  vxab = ad->v[ia].x - ad->v[ib].x;
	  vyab = ad->v[ia].y - ad->v[ib].y;
	  vzab = ad->v[ia].z - ad->v[ib].z;
	  
	  rxab = ad->x[ia].x - ad->x[ib].x;
	  ryab = ad->x[ia].y - ad->x[ib].y;
	  rzab = ad->x[ia].z - ad->x[ib].z;
	  rvab = rxab*vxab+ryab*vyab+rzab*vzab;
	  rma = 1.0/ad->w[ia];
	  rmb = 1.0/ad->w[ib];
	  
	  /* printf("rvab %e\n", rvab); */
	  /* velocity Lagrange multiplier; r0^2 stands in for |r_ab|^2 since
	     the positions have already been constrained */
	  gab = -rvab / ((rma+rmb) * rabsq);
	  if (fabs(gab) > rt->tolerance) {
	    /* Do RATTLE */
	    dx = rxab * gab;
	    dy = ryab * gab;
	    dz = rzab * gab;
	    ad->v[ia].x += rma * dx;
	    ad->v[ia].y += rma * dy;
	    ad->v[ia].z += rma * dz;
	    ad->v[ib].x -= rmb * dx;
	    ad->v[ib].y -= rmb * dy;
	    ad->v[ib].z -= rmb * dz;
	    
	    ad->ex[ia].id = ad->ex[ib].id = 2;
	    done = 0;
	  }
	}
      }
    }
    /* atoms corrected in this pass (id 2) are re-checked next pass (id 1);
       all others are marked 0 */
#ifdef MPI_SDMD
    for (ia = ad->node_atom_h; ia>=0; ia=ad->node_atom_n[ia]) {
#else  
    for (ia=0;ia<ad->natom;ia++) {
#endif
      if (ad->ex[ia].id == 2) {
	/* if the atom is involved in RATTLE */
	ad->ex[ia].id = 1;
      } else {
	/* if the atom is not involved in RATTLE */
	ad->ex[ia].id = 0;
      }
    }
  }
  /* lprintf("B loop = %d\n", loop);  */
}

/* Flag bond i as a RATTLE bond, count it in rt->n_bond and append it to the
   singly linked list bd->rhead -> bonds[].rnext.  *prev is the index of the
   last bond appended so far (-1 while the list is still empty). */
static void rattle_append_bond(RATTLE *rt, BOND_DATA *bd, int i, int *prev)
{
  rt->n_bond++;
  bd->bonds[i].flag |= RATTLE_FLAG;
  bd->bonds[i].rnext = -1;
  if (*prev < 0)
    bd->rhead = i;
  else
    bd->bonds[*prev].rnext = i;
  *prev = i;
}

/* Select the bonds to be constrained by RATTLE and build the groups.

   rt   : RATTLE state; flag, tolerance, max_loop and n_bond are (re)set.
   bd   : bond data.  RATTLE_FLAG is set on the selected bonds and cleared on
          the rest (except RATTLE_ALL, where every bond is selected), and the
          selected bonds are chained in ascending order from bd->rhead
          through bonds[].rnext (-1 terminated; bd->rhead is -1 if none).
   ad   : atom data, passed on to RATTLE_setup_group.
   type : RATTLE_ALL (not available with MPI_SDMD), RATTLE_HYDR (bonds
          flagged HYDROGEN_INCLUDED), RATTLE_WAT (bonds flagged WATER_BOND)
          or RATTLE_NONE.  Any other value is a fatal error.
   tol  : tolerance used by RATTLE_a / RATTLE_b. */
void RATTLE_setup(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad,
		  RATTLE_TYPE type, double tol)
{
  int i;
  int prev = -1;

  rt->flag = 1;
  rt->tolerance = tol;
  rt->max_loop = 1000;
  rt->n_bond = 0;

  /* Start from an empty RATTLE bond list.  Otherwise a selection matching
     no bond would leave bd->rhead stale, and the MPI_SDMD loops in
     RATTLE_a / RATTLE_b walk the list starting from bd->rhead. */
  bd->rhead = -1;

  lprintf("RATTLE Setup:\n");
  lprintf("  Target : ");

  switch (type) {
  case RATTLE_ALL:
#ifdef MPI_SDMD
    lprintf("RATTLE ALL is not supported in SDMD.\n");
    marble_exit(1);
#else
    lprintf("all bonds\n");
    for (i=0;i<bd->n_bond;i++)
      rattle_append_bond(rt, bd, i, &prev);
#endif
    break;
  case RATTLE_HYDR:
    lprintf("hydrogen-including bonds\n");
    for (i=0;i<bd->n_bond;i++) {
      if (bd->bonds[i].flag & HYDROGEN_INCLUDED)
	rattle_append_bond(rt, bd, i, &prev);
      else
	bd->bonds[i].flag &= ~RATTLE_FLAG;
    }
    break;
  case RATTLE_WAT:
    lprintf("water bonds\n");
    for (i=0;i<bd->n_bond;i++) {
      if (bd->bonds[i].flag & WATER_BOND)
	rattle_append_bond(rt, bd, i, &prev);
      else
	bd->bonds[i].flag &= ~RATTLE_FLAG;
    }
    break;
  case RATTLE_NONE:
    lprintf("none\n");
    rt->flag = 0;
    bd->rhead = -1;
    for (i=0;i<bd->n_bond;i++) {
      bd->bonds[i].flag &= ~RATTLE_FLAG;
    }
    break;
  default:
    /* An unknown selection previously left RATTLE enabled with no bonds
       flagged; treat it as a configuration error instead. */
    lprintf("unknown (%d)\n", (int) type);
    lprintf("ERROR: unsupported RATTLE type.\n");
    marble_exit(1);
    break;
  }
  if (type != RATTLE_NONE) {
    RATTLE_setup_group(rt, bd, ad);
    lprintf("  Tolerance = %e\n",tol);
    lprintf("  Number of rattle bonds = %d\n",  rt->n_bond);
    lprintf("  Number of rattle groups = %d\n", rt->n_grp);
    lprintf("\n");
  }
}

/*
 * RATTLE_setup_group: merge atoms connected by RATTLE bonds into groups and
 * fill rt->grp, rt->n_grp and rt->node_grp_h.
 *
 * Group representation in ad->ex[]:
 *  - each group has a root atom flagged ATOM_PARENT;
 *  - the other members are flagged ATOM_CHILD;
 *  - members are reached by following child_list from the root
 *    (-1 terminates the list).
 * Every atom of a RATTLE bond gets ATOM_RATTLE.  A RATTLE bond involving a
 * rigid-body or fixed atom is a fatal error.
 *
 * NOTE(review): ATOM_PARENT / ATOM_CHILD are not cleared in the first loop,
 * so this appears to assume a single call on freshly read flags.  Confirm.
 */
void RATTLE_setup_group(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad)
{
  char *p1,*p2;
  int i, j, n, atom1, atom2, parent, child, atom1_parent, atom2_parent;
  
  /* every non-rigid atom starts as a one-atom group rooted at itself */
  for (i=0; i<ad->natom; i++) {
    if (ad->ex[i].flag & ATOM_RIGID) continue;
    ad->ex[i].child_list = -1;
    ad->ex[i].parent = i;
  }
  
  for (i=0;i<bd->n_bond;i++) {
    if (!(bd->bonds[i].flag & RATTLE_FLAG)) continue;
    /* Here, rattle bonds */
    /* Consider that we merge two groups... */
    atom1 = bd->bonds[i].atom1;
    atom2 = bd->bonds[i].atom2;

    if ((ad->ex[atom1].flag & ATOM_RIGID) ||
	(ad->ex[atom2].flag & ATOM_RIGID)) {
      lprintf("ERROR: One of atoms (%d,%d) involved in the rattle bond is in rigid body.\n",atom1+1,atom2+1);
      marble_exit(1);
    }

    if ((ad->ex[atom1].flag & ATOM_FIXED) ||
	(ad->ex[atom2].flag & ATOM_FIXED)) {
      lprintf("ERROR: One of atoms (%d,%d) involved in the rattle bond is fixed\n",atom1+1,atom2+1);
      marble_exit(1);
    }

    ad->ex[atom1].flag |= ATOM_RATTLE;
    ad->ex[atom2].flag |= ATOM_RATTLE;
    
    /* roots of the groups the two bonded atoms currently belong to */
    atom1_parent = ad->ex[atom1].parent;
    atom2_parent = ad->ex[atom2].parent;

    /* if parents are common, do nothing. */
    if (atom1_parent == atom2_parent) continue;

    /* Choose the root of the merged group: if exactly one root's atom
       symbol starts with 'H', keep the other (heavy-atom) root; otherwise
       keep atom1's root. */
    p1 = get_atom_sym(ad,atom1_parent);
    p2 = get_atom_sym(ad,atom2_parent);

    if (*p1 == 'H' && *p2 == 'H') {
      /* Both are hydrogen */
      parent = atom1_parent;
      child  = atom2_parent;
    } else if (*p1 != 'H' && *p2 != 'H') {
      /* Both are the heavy atom */
      parent = atom1_parent;
      child  = atom2_parent;
    } else {
      /* One is the heavy atom and the other is hydrogen */
      if (*p1 != 'H') {
	parent = atom1_parent;
	child  = atom2_parent;
      } else {
	parent = atom2_parent;
	child  = atom1_parent;
      }
    }

    /* job for child group: re-root every member of the absorbed group and
       splice the parent's existing member list after its last member */
    for (j=child;j>=0;j=ad->ex[j].child_list) {
      ad->ex[j].flag &= ~ATOM_PARENT;
      ad->ex[j].flag |= ATOM_CHILD;
      ad->ex[j].parent = parent;

      if (ad->ex[j].child_list < 0) {
	/* connect parent's child_list to last of child's child_list */
	ad->ex[j].child_list = ad->ex[parent].child_list;
	break;
      }
    }

    /* job for parent group: its list now starts with the absorbed group */
    ad->ex[parent].flag |= ATOM_PARENT;
    ad->ex[parent].child_list = child;
  }

  /* count group */
  rt->n_grp = 0;
  for (i=0;i<ad->natom;i++)
    if ((ad->ex[i].flag & ATOM_RATTLE) && (ad->ex[i].flag & ATOM_PARENT))
      rt->n_grp++;

  /* NOTE(review): with rt->n_grp == 0 this requests a zero-byte block;
     whether that is accepted depends on emalloc -- confirm. */
  rt->grp = emalloc("RATTLE_setup_group", sizeof(RATTLE_GROUP)*rt->n_grp);
  
  /* Fill one RATTLE_GROUP per root: root atom, total mass, and a forward
     chain node_grp_n over the groups in ascending root-atom order. */
  n = 0;
  for (i=0;i<ad->natom;i++) {
    if ((ad->ex[i].flag & ATOM_RATTLE) && (ad->ex[i].flag & ATOM_PARENT)) {
      rt->grp[n].parent_atom = i;
      rt->grp[n].node_grp_n = n + 1;
      /* NOTE(review): the root's `parent` field is overwritten with the
         group index n, while child atoms keep the root *atom* index. */
      ad->ex[i].parent = n;
      rt->grp[n].w = ad->w[i];
      for (j=ad->ex[i].child_list;j>=0;j=ad->ex[j].child_list) {
	rt->grp[n].w += ad->w[j];
      }
      n++;
    }
  }
  /* terminate the group chain (or mark it empty) */
  if (rt->n_grp == 0)
    rt->node_grp_h = -1;
  else {
    rt->node_grp_h = 0;
    rt->grp[n-1].node_grp_n = -1;
  }
}
 

/* Count degrees of freedom of RATTLE atoms per extended system:
   3 per ATOM_RATTLE atom, minus 1 per RATTLE bond.  Also reports the number
   of RATTLE groups.  A RATTLE bond whose two atoms lie in different
   extended systems is a fatal error. */
void RATTLE_degree_of_freedom(RATTLE *rt, BOND_DATA *bd, ATOM_DATA *ad,
                              int *degree_of_freedom_ex, int *n_rattle_group)
{
  int k, a1, a2, ex1, ex2;

  *n_rattle_group = rt->n_grp;

  for (k = 0; k < MAX_EX_SYSTEM; k++)
    degree_of_freedom_ex[k] = 0;

  /* three Cartesian degrees of freedom per constrained atom ... */
  for (k = 0; k < ad->natom; k++) {
    if (!(ad->ex[k].flag & ATOM_RATTLE))
      continue;
    degree_of_freedom_ex[ATOM_FLAG_EX_SYSTEM(ad->ex[k].flag)] += 3;
  }

  /* ... minus one per bond-length constraint */
  for (k = 0; k < bd->n_bond; k++) {
    if (!(bd->bonds[k].flag & RATTLE_FLAG))
      continue;
    a1 = bd->bonds[k].atom1;
    a2 = bd->bonds[k].atom2;
    ex1 = ATOM_FLAG_EX_SYSTEM(ad->ex[a1].flag);
    ex2 = ATOM_FLAG_EX_SYSTEM(ad->ex[a2].flag);
    if (ex1 != ex2) {
      lprintf("ERROR: atom1 %d and atom2 %d, which are connected by RATTLE bond, belong to different extended systems (%d,%d).\n",a1+1,a2+1,ex1,ex2);
      marble_exit(1);
    }
    degree_of_freedom_ex[ex1]--;
  }
}

/* Subtract from ad->virial, for every atom of every RATTLE group, the
   products f_a * d_b, where d = x - rg is the atom's offset from its group's
   mass-weighted centre.  Layout of the nine virial entries:
     0: fx*dx  1: fy*dy  2: fz*dz
     3: fx*dy  4: fx*dz  5: fy*dz
     6: fy*dx  7: fz*dx  8: fz*dy */
void RATTLE_correct_virial(RATTLE *rt, ATOM_DATA *ad)
{
  int g, atom;
  RATTLE_GROUP *grp;
  double rx, ry, rz, fx, fy, fz;

  RATTLE_calc_rg(rt, ad);

  for (g = rt->node_grp_h; g >= 0; g = rt->grp[g].node_grp_n) {
    grp = &rt->grp[g];
    for (atom = grp->parent_atom; atom >= 0; atom = ad->ex[atom].child_list) {
      rx = ad->x[atom].x - grp->rg.x;
      ry = ad->x[atom].y - grp->rg.y;
      rz = ad->x[atom].z - grp->rg.z;
      fx = ad->f[atom].x;
      fy = ad->f[atom].y;
      fz = ad->f[atom].z;

      ad->virial[0] -= fx * rx;
      ad->virial[1] -= fy * ry;
      ad->virial[2] -= fz * rz;

      ad->virial[3] -= fx * ry;
      ad->virial[4] -= fx * rz;
      ad->virial[5] -= fy * rz;

      ad->virial[6] -= fy * rx;
      ad->virial[7] -= fz * rx;
      ad->virial[8] -= fz * ry;
    }
  }
}
