/*
 * 
 * This source code is part of 
 *   MARBLE (MoleculAR simulation package for BiomoLEcules)
 * 
 * Written by Mitsunori Ikeguchi
 * Copyright (c) 2012 Yokohama City University
 *  
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 */

#include <stdio.h>
#include <stdlib.h>
#include <math.h>

#include "misc.h"
#include "md_system.h"
#include "min.h"

#ifdef MPI_RDMD
#include "parallel.h"
#endif
#ifdef MPI_SDMD
#include "parallel.h"
#include "sdmd.h"
#endif

/* Upper bound on the number of interpolation refinements performed in one
   conjugate-gradient line minimization (see MIN_conjugate_gradient). */
#define LINE_MIN_MAX_STEP  10
/* NOTE(review): declared here but neither defined nor called in this
   source; presumably defined in another translation unit or a leftover
   from an earlier variant -- confirm before removing. */
void MIN_conjugate_gradient2(MD_SYSTEM *, int, double,double,double);

/*
 * One evaluated point of a line search along the current search
 * direction v (filled in by MIN_calc_energy).
 */
typedef struct {
  double x;    /* line parameter: position along the search direction */
  double u;    /* potential energy at this point (sys->potential) */
  double ff;   /* f.f, squared norm of the force vector */
  double vf;   /* v.f, force projected onto the search direction */
} MIN_POINT;


/*
 * Energy-minimization driver: steepest descent, optionally followed by
 * conjugate gradient.
 *
 * sd_step, sd_init_step, sd_largest_step : steepest-descent parameters
 * rms             : convergence criterion on the RMS gradient (both stages)
 * print_out_step  : status print interval (both stages)
 * cg_step, cg_init_step, cg_largest_step, cg_line_min_tol :
 *                   conjugate-gradient parameters; CG runs only if
 *                   cg_step > 0
 *
 * Rigid bodies are rejected for any minimization, RATTLE for the CG stage.
 * In SDMD builds the coordinates are gathered at the end.
 */
void MIN_main(MD_SYSTEM *sys, 
	      int sd_step, double sd_init_step, double sd_largest_step,
	      double rms,  int print_out_step,
	      int cg_step, double cg_init_step, double cg_largest_step,
	      double cg_line_min_tol)
{
  if (sys->rigid_mol_flag) {
    lprintf("Rigid body can't be used in minimization.\n");
    marble_exit(1);
  }

  /* Check the CG precondition before doing any work, so an incompatible
     setup is reported immediately rather than after a possibly long
     steepest-descent run. */
  if (cg_step > 0 && sys->rattle.flag) {
    lprintf("Rattle can't be used in conjugate gradient minimization.\n");
    marble_exit(1);
  }

  /* Always called, even when sd_step == 0: MIN_steepest also performs the
     VERLET_setup()/SDMD_setup() call that the CG stage does not repeat. */
  MIN_steepest(sys, sd_step, sd_init_step, sd_largest_step, rms,
	       print_out_step);

  if (cg_step > 0) {
    MIN_conjugate_gradient(sys, cg_step, cg_init_step, cg_largest_step,
			   cg_line_min_tol, rms, print_out_step);
  }

#ifdef MPI_SDMD
  SDMD_gather_x(&sys->linked_cell, &sys->atom);
#endif  
}

/*
 * Print one steepest-descent progress line (step index, potential energy,
 * step length, RMS gradient), then the energy breakdown and a separator,
 * and flush the log. In SDMD builds only the master rank reports.
 */
void MIN_print_status(MD_SYSTEM *sys, int step, double dstep, double rms)
{
  int is_reporter = 1;
  double energy;

#ifdef MPI_SDMD
  is_reporter = mpi.master;
#endif  
  if (!is_reporter)
    return;

  energy = sys->potential;
  lprintf("STEP  %8d ENERGY %10.2f STEP_LEN  %.2e GRAD  %.2e\n",
          step, energy, dstep, rms);
  MD_SYSTEM_print_energy(sys);
  lprintf_bar();
  lflush();
}

/*
 * Steepest-descent minimization with an adaptive step length.
 *
 * Each iteration displaces the atoms along the force by width = dstep/|F|,
 * so the 3N-dimensional displacement vector (before any RATTLE correction)
 * has norm dstep. After a step, dstep is doubled if the potential energy
 * dropped and halved otherwise, then clamped to [1.0e-3, dstep_largest]
 * (the upper clamp wins if dstep_largest < 1.0e-3). Uphill steps are kept,
 * not rejected.
 *
 * sys            : system to minimize; coordinates are updated in place
 * max_step       : maximum number of iterations (>= 0)
 * dstep_init     : initial step length (> 0)
 * dstep_largest  : upper bound of the step length (> 0)
 * rms_cri        : stop when |F|/sqrt(natom) < rms_cri (> 0)
 * print_out_step : status interval; <= 0 prints only the last iteration
 *
 * Invalid arguments abort through marble_exit(1).
 */
void MIN_steepest(MD_SYSTEM *sys, int max_step, double dstep_init,
		  double dstep_largest, double rms_cri, int print_out_step)
{
  const double dstep_smallest = 1.0e-3;   /* lower clamp for dstep */
  int step, i, last_printed;
  double dstep, flen, e_prev, ene, width, rms;
  ATOM_DATA *ad;

  ad = &sys->atom;
  dstep = dstep_init;

  lprintf("Steepest Decent Minimization:\n");
  lprintf("  Maximum step = %d\n", max_step);
  lprintf("  Initial step length = %f\n", dstep_init);
  lprintf("  Largest step length = %f\n", dstep_largest);
  lprintf("  Convergence criteria of grad vector = %.2e [kcal/mol/angstrom]\n\n", rms_cri);

  if (max_step < 0) {
    lprintf("  ERROR: Maximum step number must be positive or zero.\n");
    marble_exit(1);
  }
  if (dstep_init <= 0.0) {
    lprintf("  ERROR: Initial step length must be positive.\n");
    marble_exit(1);
  }
  if (dstep_largest <= 0.0) {
    lprintf("  ERROR: Largest step length must be positive.\n");
    marble_exit(1);
  }
  if (rms_cri <= 0.0) {
    lprintf("  ERROR: Convergence criteria must be positive.\n");
    marble_exit(1);
  }

#ifdef MPI_SDMD
  SDMD_setup(sys);
#else  
  VERLET_setup(sys);
#endif  

  flen = MIN_flen(ad);
  rms = flen/sqrt(ad->natom);
  lprintf("\n"); lprintf_bar();
  MIN_print_status(sys,0, dstep, rms);
  last_printed = 0;       /* step index of the most recent status line */
  e_prev = sys->potential;

  for (step=0; step < max_step; step++) {
    flen = MIN_flen(ad);
    rms = flen/sqrt(ad->natom);
    if (rms < rms_cri) {                /* convergence */
      lprintf("Converged!\n");
      /* Report the converged state unless it was just printed. */
      if (last_printed != step)
        MIN_print_status(sys, step, dstep, rms);
      break;
    }
    /* Divide only after the convergence test: here rms >= rms_cri > 0,
       so flen is nonzero. */
    width = dstep / flen;

    if (sys->rattle.flag)
      RATTLE_backup_x(&sys->atom);

#ifdef MPI_SDMD
    for (i = ad->node_fatom_h; i>=0; i=ad->node_fatom_n[i]) {
#else      
    for (i = 0; i < sys->atom.natom; i++) {
#endif      
      ad->x[i].x += width * ad->f[i].x;
      ad->x[i].y += width * ad->f[i].y;
      ad->x[i].z += width * ad->f[i].z;
    }
    
    if (sys->rattle.flag) {
      dtime();
      RATTLE_a(&sys->rattle, &sys->bond, &sys->atom, sys->dt);
      add_dtime(&sys->time_rattle);
    }

#ifdef MPI_SDMD
    if (SDMD_check_migration(sys,step)) {
      SDMD_migration(sys, step, sys->dt_ps);
    } else {
      dtime();
      SDMD_dist_x_xyz(&sys->linked_cell, &sys->atom);
      if (sys->linked_cell.tr_dist_xf_flag)
	SDMD_dist_x(&sys->linked_cell, &sys->atom);
      add_dtime(&sys->linked_cell.time[SDMD_TIME_COMM_X]);
    }
    SDMD_calc_force(sys);
#else    
    if (NONBOND_LIST_check_update(&sys->nonbond, &sys->atom,
				  &sys->boundary, step+1)) {
      MD_SYSTEM_make_nonbond_list(sys);
    }
    MD_SYSTEM_calc_force(sys);
#endif
    
    /* Note: rms printed here is the gradient before this step's move. */
    if ((print_out_step > 0 &&
	 (step+1) % print_out_step == 0) ||
	step == max_step - 1) {
      MIN_print_status(sys,step+1,dstep,rms);
      last_printed = step+1;
    }
    
    /* Adapt the step length: grow after a downhill step, shrink after an
       uphill one. */
    ene = sys->potential;
    if (ene < e_prev) {
      dstep *= 2.0;
    } else {
      dstep *= 0.5;
    }
    if (dstep < dstep_smallest) dstep = dstep_smallest;
    if (dstep > dstep_largest)  dstep = dstep_largest;
    e_prev = ene;
  }
}

/*
 * Return |F|, the Euclidean norm of the 3N-dimensional force vector ad->f.
 *
 * Serial build: sums over atoms 0..natom-1.
 * SDMD build: each rank sums over its node_fatom list (presumably the atoms
 * it owns -- confirm) and the partial sums are combined with
 * MPI_Allreduce, so every rank returns the same global norm.
 */
double MIN_flen(ATOM_DATA *ad)
{
  int i;
  double prod = 0.0;
#ifdef MPI_SDMD
  double prod_all;   /* only needed for the reduction */
#endif

#ifdef MPI_SDMD
  for (i = ad->node_fatom_h; i >= 0; i = ad->node_fatom_n[i]) {
#else    
  for (i = 0; i < ad->natom; i++) {
#endif    
    prod += ad->f[i].x*ad->f[i].x
           +ad->f[i].y*ad->f[i].y
           +ad->f[i].z*ad->f[i].z;
  }

#ifdef MPI_SDMD
  MPI_Allreduce(&prod, &prod_all, 1, MPI_DOUBLE, MPI_SUM, mpi.comm);
  prod = prod_all;
#endif  
  return sqrt(prod);
}

/* new_v = c*v + f */
/*
 * Update the search direction, new_v = c*v + f, for every atom on the
 * node_fatom list, and return through *vv and *vf the dot products v.v and
 * v.f of the updated direction (summed over all ranks in SDMD builds).
 * c == 0 resets the direction to the force itself.
 *
 * If rigid molecules are enabled, RMOL_DATA_set_direction is called after
 * the sums are formed; its contribution is not included in *vv / *vf
 * (MIN_main rejects rigid bodies, so this path is presumably unused).
 * RATTLE is not handled here; MIN_main rejects it for CG.
 */
void MIN_set_direction(MD_SYSTEM *sys, double c, double *vv, double *vf)
{
  int i;
  double sum_vv = 0.0, sum_vf = 0.0;   /* local accumulators */
  ATOM_DATA *ad = &sys->atom;
#ifdef MPI_SDMD
  double local[2], global[2];
#endif

  for (i = ad->node_fatom_h; i >= 0; i = ad->node_fatom_n[i]) {
    ad->v[i].x = ad->v[i].x * c + ad->f[i].x;
    ad->v[i].y = ad->v[i].y * c + ad->f[i].y;
    ad->v[i].z = ad->v[i].z * c + ad->f[i].z;
    sum_vv += ad->v[i].x * ad->v[i].x + ad->v[i].y * ad->v[i].y + ad->v[i].z * ad->v[i].z;
    sum_vf += ad->v[i].x * ad->f[i].x + ad->v[i].y * ad->f[i].y + ad->v[i].z * ad->f[i].z;
  }

#ifdef MPI_SDMD
  local[0] = sum_vv;
  local[1] = sum_vf;
  MPI_Allreduce(local, global, 2, MPI_DOUBLE, MPI_SUM, mpi.comm);
  sum_vv = global[0];
  sum_vf = global[1];
#endif  
  *vv = sum_vv;
  *vf = sum_vf;
  
  if (sys->rigid_mol_flag) {
    RMOL_DATA_set_direction(&sys->rigid_mol, ad, c);
  }
}

/*
 * Move the system from line parameter mp->x to x along the current search
 * direction, recompute forces and energy, and store the new point in *mp.
 *
 * The coordinate update is relative: the integrator receives the increment
 * x - mp->x, so on entry mp->x must be the line parameter of the current
 * coordinates. (Presumably VERLET_integrate_coord/SDMD_integrate_coord add
 * increment * ad->v to the coordinates -- TODO confirm.)
 *
 * On return: mp->x = x, mp->u = sys->potential, mp->ff = f.f and
 * mp->vf = v.f, summed over the node_fatom list and, in SDMD builds,
 * over all ranks.
 *
 * RATTLE is not handled here; MIN_main rejects it for CG.
 */
void MIN_calc_energy(MD_SYSTEM *sys, double x, MIN_POINT *mp)
{
  int i;
  double ff = 0.0, vf = 0.0;   /* local accumulators */
  ATOM_DATA *ad = &sys->atom;
#ifdef MPI_SDMD
  double local[2], global[2];
#endif

#ifdef MPI_SDMD
  SDMD_integrate_coord(sys, x - mp->x);
  /* NOTE(review): presumably redistributes atoms that left their domain
     during the move -- confirm. */
  SDMD_migration(sys, 0, sys->dt_ps);
  SDMD_calc_force(sys);
#else  
  VERLET_integrate_coord(sys, x - mp->x);
  /* The nonbond list is rebuilt unconditionally on every evaluation. */
  MD_SYSTEM_make_nonbond_list(sys);
  MD_SYSTEM_calc_force(sys);
#endif

  mp->u = sys->potential;
  mp->x = x;

  for (i = ad->node_fatom_h; i >= 0; i = ad->node_fatom_n[i]) {
    ff += ad->f[i].x * ad->f[i].x;
    ff += ad->f[i].y * ad->f[i].y;
    ff += ad->f[i].z * ad->f[i].z;
    vf += ad->v[i].x * ad->f[i].x;
    vf += ad->v[i].y * ad->f[i].y;
    vf += ad->v[i].z * ad->f[i].z;
  }
  
#ifdef MPI_SDMD
  local[0] = ff;
  local[1] = vf;
  MPI_Allreduce(local, global, 2, MPI_DOUBLE, MPI_SUM, mpi.comm);
  ff = global[0];
  vf = global[1];
#endif  
  mp->ff = ff;
  mp->vf = vf;
}

/*
 * Print one conjugate-gradient progress line (step, potential energy, RMS
 * gradient), followed by the energy breakdown and a separator, then flush.
 * Non-master ranks print nothing in SDMD builds.
 */
void MIN_print_cg(MD_SYSTEM *sys, int step, double rms)
{
  double energy;

#ifdef MPI_SDMD
  if (!mpi.master)
    return;
#endif  
  energy = sys->potential;
  lprintf("CG_STEP %11d  ENERGY%13.4f  GRAD %14.2e\n", step, energy, rms);
  MD_SYSTEM_print_energy(sys);
  lprintf_bar();
  lflush();
}

/*
 * Nonlinear conjugate-gradient minimization.
 *
 * Search direction: v <- c*v + f with the Fletcher-Reeves coefficient
 * c = |f_new|^2 / |f_old|^2. The direction is reset to the force (c = 0)
 * when c^2 |v|^2 > 100 |f_new|^2.
 *
 * Each line search along v works in two phases:
 *   1. Bracketing. Start at x = init_step*sqrt(ff/vv), with init_step capped
 *      at largest_step, and double x while the energy keeps decreasing.
 *   2. Refinement. Apply up to LINE_MIN_MAX_STEP safeguarded parabolic
 *      interpolations, each fitted to the energy at one bracket end and the
 *      slope v.f at the middle point.
 * Refinement stops early when |v.f| / |v.f at the first trial point| falls
 * below line_min_tol.
 *
 * init_step is doubled with every bracketing doubling and quartered
 * afterwards, so it adapts between line searches.
 *
 * max_step       : maximum number of line searches (>= 0)
 * init_step      : initial trial step scale (> 0)
 * largest_step   : upper bound on init_step (> 0)
 * line_min_tol   : relative slope tolerance of the line search
 * rms_cri        : stop when sqrt(f.f/natom) < rms_cri
 * print_out_step : status interval; <= 0 prints only the last step
 *
 * Invalid arguments abort through marble_exit(1).
 */
void MIN_conjugate_gradient(MD_SYSTEM *sys, int max_step,
			    double init_step, double largest_step,
			    double line_min_tol, double rms_cri, int print_out_step)
{
  int i, step;
  double x, ff, vv, min_vf, c, min_x, max_x, dx, rms;
  /* Bracket points: intended ordering mp1.x < mp2.x < mp3.x with mp2 the
     lowest. If the first trial step already raises the energy, mp2 is the
     origin and mp1, mp3 both hold the trial point. */
  MIN_POINT mp, mp1, mp2, mp3;

  lprintf("\n");
  lprintf("Conjugate Gradient Minimization:\n");
  lprintf("  Maximum step = %d\n", max_step);
  lprintf("  Initial step length = %e\n", init_step);
  lprintf("  Largest step length = %e\n", largest_step);
  lprintf("  Line minimization tolerance = %e\n", line_min_tol);
  lprintf_bar();

  /* Same argument checks as MIN_steepest(). */
  if (max_step < 0) {
    lprintf("  ERROR: Maximum step number must be positive or zero.\n");
    marble_exit(1);
  }
  if (init_step <= 0.0) {
    lprintf("  ERROR: Initial step length must be positive.\n");
    marble_exit(1);
  }
  if (largest_step <= 0.0) {
    lprintf("  ERROR: Largest step length must be positive.\n");
    marble_exit(1);
  }

  /* Evaluate the starting point and start along the force (c = 0). */
  mp.x = 0.0;
  MIN_calc_energy(sys, 0.0, &mp);
  ff = mp.ff;
  c = 0.0;
  MIN_set_direction(sys, c, &vv, &min_vf);

  rms = sqrt(ff/sys->atom.natom);
  MIN_print_cg(sys, 0, rms);

  /* Already converged (e.g. by the preceding steepest descent): stop before
     the first line search, which would divide by vv == 0 if the forces
     vanish. */
  if (rms < rms_cri) {
    lprintf("Converged!\n");
    return;
  }
  
  for (step=0;step<max_step;step++) {
    /* ---- line minimization, phase 1: bracketing ---- */
    if (init_step > largest_step) init_step = largest_step;
    x  = init_step * sqrt(ff/vv);
    /* NOTE(review): mp.vf was computed with the previous direction, so
       mp2.vf is stale for the new v until mp2 is replaced -- confirm that
       this is intended. */
    mp2 = mp;
    MIN_calc_energy(sys, x, &mp);
    /* NOTE(review): this overwrites the origin slope v.f returned by
       MIN_set_direction with the slope at the first trial point; the
       tolerance test below is relative to this value. */
    min_vf = mp.vf;
    mp1 = mp;

    while (mp.u < mp2.u) {
      x *= 2.0;  init_step *= 2.0;
      mp1 = mp2;
      mp2 = mp;
      MIN_calc_energy(sys, x, &mp);
    }
    mp3 = mp;
    init_step *= 0.25;

    /* ---- line minimization, phase 2: safeguarded parabolic steps ----
       vf = v.f is the negative slope dU/dx if f = -grad U, so
       mp2.vf < 0 means the minimum lies between mp1 and mp2. The new x is
       kept 10% away from both ends of that interval. The lower clamp is
       written as !(x >= min_x) so that a NaN from a degenerate parabola
       (zero denominator) is clamped instead of propagated; for finite x it
       is identical to x < min_x. */
    for (i=0; i<LINE_MIN_MAX_STEP; i++) {
      if (fabs(mp.vf/min_vf) < line_min_tol) break;

      if (mp2.vf < 0.0) {                       /* dividing between mp1 and mp2 */
	dx = mp1.x - mp2.x;
	x = mp2.x + (mp2.vf*dx*dx)/(2.0*(mp1.u - mp2.u + mp2.vf * dx));
	min_x = 0.9 * mp1.x + 0.1 * mp2.x;
	max_x = 0.1 * mp1.x + 0.9 * mp2.x;
	if (!(x >= min_x)) x = min_x;
	if (x > max_x) x = max_x;

	MIN_calc_energy(sys, x, &mp);

	if (mp.u <= mp2.u) {
	  mp3 = mp2;  mp2 = mp;
	} else {
	  mp1 = mp;
	}
      } else {                                  /* dividing between mp2 and mp3 */
	dx = mp3.x - mp2.x;
	x = mp2.x + (mp2.vf*dx*dx)/(2.0*(mp3.u - mp2.u + mp2.vf * dx));
	min_x = 0.9 * mp2.x + 0.1 * mp3.x;
	max_x = 0.1 * mp2.x + 0.9 * mp3.x;
	if (!(x >= min_x)) x = min_x;
	if (x > max_x) x = max_x;

	MIN_calc_energy(sys, x, &mp);

	if (mp.u <= mp2.u) {
	  mp1 = mp2;  mp2 = mp;
	} else {
	  mp3 = mp;
	}
      }
    }  /* end of line_min */

    /* The coordinates sit at the last evaluated point (mp), which is not
       necessarily the lowest one (mp2). That point becomes the origin of
       the next line search. */
    mp.x = 0.0;

    /* Fletcher-Reeves coefficient; restart if the carried-over direction
       would dominate the new force. */
    c   = mp.ff/ff;
    ff  = mp.ff;

    if (c*c*vv > 100.0*ff) {
      lprintf("Resetting direction at step %d\n", step+1);
      c = 0.0;
    }

    /* new_v = c * old_v + f */
    MIN_set_direction(sys, c, &vv, &min_vf);
    rms = sqrt(mp.ff/sys->atom.natom);
    if (rms < rms_cri) {
      lprintf("Converged!\n");
      MIN_print_cg(sys, step+1, rms);
      return;
    }

    /* Guard print_out_step > 0 to avoid a modulo by zero, and always report
       the last step (as MIN_steepest does). */
    if ((print_out_step > 0 && (step+1) % print_out_step == 0) ||
	step == max_step - 1)
      MIN_print_cg(sys, step+1, rms);
  }
}

