/*
 * 
 * This source code is part of 
 *   MARBLE (MoleculAR simulation package for BiomoLEcules)
 * 
 * Written by Mitsunori Ikeguchi
 * Copyright (c) 2012 Yokohama City University
 *  
 * This program is free software; you can redistribute it and/or
 * modify it under the terms of the GNU General Public License
 * as published by the Free Software Foundation; either version 2
 * of the License, or (at your option) any later version.
 * 
 * This program is distributed in the hope that it will be useful,
 * but WITHOUT ANY WARRANTY; without even the implied warranty of
 * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 * GNU General Public License for more details.
 * 
 */

#include <stdio.h>
#include <math.h>
#include <stdlib.h>
#include <stdarg.h>

#include "md_system.h"
#include "parallel.h"

/* #define ALL_NBLIST_ON_ALL_PE */
#define DEBUG 0

#ifdef MPI_RDMD

/*
 * Set up the per-PE task ranges for the atom, bond, angle and dihedral
 * data of sys by calling the matching assign_tasks_*() routine on each.
 */
void assign_tasks(MD_SYSTEM *sys)
{
  /* atoms, then the bonded terms: bonds, angles, dihedrals */
  assign_tasks_atom(&sys->atom);
  assign_tasks_bond(&sys->bond);
  assign_tasks_angle(&sys->angle);
  assign_tasks_dihedral(&sys->dihed);
}

/*
 * Compute the contiguous, inclusive index range [*start_task, *end_task]
 * that the calling PE (mpi.rank) owns when n_data items are divided
 * among mpi.n_pe PEs.
 *
 * The first (n_data % mpi.n_pe) ranks get one item more than the others.
 * A PE that owns no item ends up with *end_task == *start_task - 1.
 */
void calc_number_tasks(int n_data, int *start_task, int *end_task)
{
  const int base  = n_data / mpi.n_pe;   /* items every PE gets      */
  const int extra = n_data % mpi.n_pe;   /* PEs that get one more    */
  int first, count;

  if (mpi.rank < extra) {
    /* this PE is one of the "base + 1" ranks */
    count = base + 1;
    first = count * mpi.rank;
  } else {
    /* this PE is one of the "base" ranks, placed after the larger ones */
    count = base;
    first = extra + base * mpi.rank;
  }

  *start_task = first;
  *end_task   = first + count - 1;
}

/*
 * Assign the atom range of the calling PE (ad->start_task / ad->end_task)
 * and fill, for every rank, ad->data_counts[] and ad->data_offsets[].
 *
 * Counts and offsets are expressed in units of three values per atom;
 * X_all_gather() passes them to MPI_Allgatherv with MPI_DOUBLE.
 */
void assign_tasks_atom(ATOM_DATA *ad)
{
  const int base  = ad->natom / mpi.n_pe;
  const int extra = ad->natom % mpi.n_pe;
  int pe;

  calc_number_tasks(ad->natom, &ad->start_task, &ad->end_task);

  for (pe = 0; pe < mpi.n_pe; pe++) {
    int n_atoms, first_atom;

    if (pe < extra) {
      /* rank pe owns base + 1 atoms */
      n_atoms    = base + 1;
      first_atom = (base + 1) * pe;
    } else {
      /* rank pe owns base atoms */
      n_atoms    = base;
      first_atom = extra + base * pe;
    }
    ad->data_counts[pe]  = n_atoms * 3;
    ad->data_offsets[pe] = first_atom * 3;
#if DEBUG
    lprintf("data_counts[%d]=%d\n", pe, ad->data_counts[pe]);
    lprintf("data_offsets[%d]=%d\n", pe, ad->data_offsets[pe]);
#endif
  }
}

/*
 * Store in bd->start_task / bd->end_task the inclusive range of bond
 * indices handled by the calling PE.
 */
void assign_tasks_bond(BOND_DATA *bd)
{
  int *first = &bd->start_task;
  int *last  = &bd->end_task;

  calc_number_tasks(bd->n_bond, first, last);

#if DEBUG
  printf("b %d: beg %d end %d\n", mpi.rank, *first, *last);
  fflush(stdout);
#endif /* DEBUG */
}

/*
 * Store in ad->start_task / ad->end_task the inclusive range of angle
 * indices handled by the calling PE.
 */
void assign_tasks_angle(ANGLE_DATA *ad)
{
  int *first = &ad->start_task;
  int *last  = &ad->end_task;

  calc_number_tasks(ad->n_angle, first, last);

#if DEBUG
  printf("a %d: beg %d end %d\n", mpi.rank, *first, *last);
  fflush(stdout);
#endif /* DEBUG */
}

/*
 * Store in dd->start_task / dd->end_task the inclusive range of dihedral
 * indices handled by the calling PE.
 */
void assign_tasks_dihedral(DIHEDRAL_DATA *dd)
{
  int *first = &dd->start_task;
  int *last  = &dd->end_task;

  calc_number_tasks(dd->n_dihedral, first, last);

#if DEBUG
  printf("d %d: beg %d end %d\n", mpi.rank, *first, *last);
  fflush(stdout);
#endif /* DEBUG */
}

/*
 * Set the inclusive range nl->start_task .. nl->end_task of non-bond list
 * entries handled by the calling PE.
 *
 * Without ALL_NBLIST_ON_ALL_PE the whole list (0 .. n_list-1) is assigned;
 * with it, the list is divided among the PEs by calc_number_tasks().
 */
void assign_tasks_nonbond_list(NONBOND_LIST *nl)
{
#ifndef ALL_NBLIST_ON_ALL_PE
  nl->start_task = 0;
  nl->end_task = nl->n_list - 1;
#else
  calc_number_tasks(nl->n_list, &nl->start_task, &nl->end_task);
#if DEBUG
  printf("d %d: beg %d end %d\n", mpi.rank, nl->start_task, nl->end_task);
  fflush(stdout);
#endif /* DEBUG */
#endif /* ALL_NBLIST_ON_ALL_PE */
}

/*
 * Set the inclusive range lc->start_task .. lc->end_task of cells handled
 * by the calling PE.
 *
 * Without ALL_NBLIST_ON_ALL_PE the cells are divided among the PEs by
 * calc_number_tasks(); with it, every PE gets all cells (0 .. n_cell-1).
 */
void assign_tasks_linked_cell(LINKED_CELL *lc)
{
#ifndef ALL_NBLIST_ON_ALL_PE
  calc_number_tasks(lc->n_cell, &lc->start_task, &lc->end_task);
#if DEBUG
  printf("d %d: beg %d end %d\n", mpi.rank, lc->start_task, lc->end_task);
  fflush(stdout);
#endif /* DEBUG */
#else
  lc->start_task = 0;
  lc->end_task = lc->n_cell - 1;
#endif /* ALL_NBLIST_ON_ALL_PE */
}

/**** reduce ..... ****/
/* Forward declarations of helpers that are defined further down in this file. */
void sync_neighbor(double *v, int count);
void MPI_force_test(MD_SYSTEM *sys, VEC *f1, VEC *f2);
/*
 * Uncomment the define below to make reduce_energy_force() call
 * marble_all_reduce() instead of MPI_Allreduce(); the definition of
 * marble_all_reduce() is compiled only when MARBLE_ALL_REDUCE is defined.
 */
/* #define MARBLE_ALL_REDUCE  */
void marble_all_reduce(double *src, double *target, int count);
     
/*
 * Sum the energy array and the force array of sys over all PEs.
 *
 * - sys->ene[0..N_ENE-1] is replaced by the sum over all PEs.
 * - The summed forces are written to sys->atom.f_reduce, after which the
 *   f and f_reduce pointers are exchanged, so sys->atom.f holds the sum and
 *   sys->atom.f_reduce the old local contribution.
 * - MPI_force_test() then compares the summed forces between rank pairs,
 *   using sys->atom.f_reduce as its receive buffer.
 *
 * Communication time is accumulated in sys->time_comm1 (energies) and
 * sys->time_comm2 (forces).
 */
void reduce_energy_force(MD_SYSTEM *sys)
{
  double ene_sum[N_ENE];
  VEC *swap;
  int k;

  dtime();
  /*
  MPI_Barrier(mpi.comm);
  add_dtime(&sys->time_comm3);
  */

  /* --- energies --- */
#ifdef MARBLE_ALL_REDUCE
  marble_all_reduce(sys->ene, ene_sum, N_ENE);
#else
  MPI_Allreduce(sys->ene, ene_sum, N_ENE, MPI_DOUBLE, MPI_SUM, mpi.comm);
#endif

  k = 0;
  while (k < N_ENE) {
    sys->ene[k] = ene_sum[k];
    k++;
  }

  add_dtime(&sys->time_comm1);

  /* --- forces --- */
#ifdef MARBLE_ALL_REDUCE
  marble_all_reduce((double*)sys->atom.f, (double*)sys->atom.f_reduce,
                    sys->atom.natom*3);
#else
  MPI_Allreduce(sys->atom.f, sys->atom.f_reduce, sys->atom.natom*3, MPI_DOUBLE,
                MPI_SUM, mpi.comm);
#endif

  /* exchange the two buffers instead of copying the result back */
  swap               = sys->atom.f_reduce;
  sys->atom.f_reduce = sys->atom.f;
  sys->atom.f        = swap;

  add_dtime(&sys->time_comm2);

  MPI_force_test(sys, sys->atom.f, sys->atom.f_reduce);
  /*
  sync_neighbor(sys->ene,    N_ENE);
  sync_neighbor((double*) sys->atom.f, sys->atom.natom*3);
  */
}

/*
 * This routine is used to avoid a bug in the DS20 cluster.
 *
 * Each even rank sends v[0..count-1] to the next rank (rank + 1), which
 * overwrites its own v with the received values.
 *
 * When the number of PEs is odd, the last (even) rank has no partner; it
 * skips the send instead of addressing the non-existent rank mpi.n_pe.
 */
void sync_neighbor(double *v, int count)
{
  MPI_Status stat;

  if (mpi.rank % 2 == 0) {
    if (mpi.rank + 1 < mpi.n_pe) {
      MPI_Send(v, count, MPI_DOUBLE, mpi.rank+1, 10, mpi.comm);
    }
  } else {
    MPI_Recv(v, count, MPI_DOUBLE, mpi.rank-1, 10, mpi.comm, &stat);
  }
}

#ifdef MARBLE_ALL_REDUCE

/*
 * Hand-written all-reduce (sum) of count doubles.
 *
 * This routine assumes dual machines and duplex communication, and that
 * the number of PEs is even.
 *
 *  1. Every odd rank sends src to its even neighbour (rank - 1), which
 *     stores the pair sum in target.
 *  2. The even ranks exchange and accumulate their partial sums with the
 *     partners rank ^ 2, rank ^ 4, ... (hypercube algorithm).
 *  3. Every even rank sends the result back to its odd neighbour.
 *
 * On return target holds the result.  On even ranks src is used as the
 * receive buffer in step 2 and is therefore overwritten.
 */
void marble_all_reduce(double *src, double *target, int count)
{
  MPI_Status stat;
  int bit, k, partner;

  if ((mpi.rank & 1) != 0) {
    /* odd rank: hand the data to the even neighbour, then wait for the sum */
    MPI_Send(src, count, MPI_DOUBLE, mpi.rank-1, 10, mpi.comm);
    MPI_Recv(target, count, MPI_DOUBLE, mpi.rank-1, 10, mpi.comm, &stat);
    return;
  }

  /* even rank: first add the neighbour's contribution to our own */
  MPI_Recv(target, count, MPI_DOUBLE, mpi.rank+1, 10, mpi.comm, &stat);
  for (k = 0; k < count; k++) {
    target[k] += src[k];
  }

  /* hypercube exchange among the even ranks */
  for (bit = 2; bit < mpi.n_pe; bit <<= 1) {
    partner = mpi.rank ^ bit;   /* xor */
    if (partner >= mpi.n_pe)
      continue;
    MPI_Sendrecv(target, count, MPI_DOUBLE, partner, 11,
                 src,    count, MPI_DOUBLE, partner, 11,
                 mpi.comm, &stat);
    for (k = 0; k < count; k++) {
      target[k] += src[k];
    }
  }

  /* pass the final sum on to the odd neighbour */
  MPI_Send(target, count, MPI_DOUBLE, mpi.rank+1, 10, mpi.comm);
  /*
  printf("result %d, %f\n", mpi.rank, target[0]);
  marble_exit(1);
  */
}
#endif  /* def MARBLE_ALL_REDUCE */

/*
 * Check that neighbouring ranks hold bit-identical force arrays.
 *
 * Every odd rank sends its forces f1 to the even rank below it.  The even
 * rank receives them into f2 and compares them component by component
 * with its own f1, printing every atom that differs.  The even ranks then
 * report their error flag to rank 0, which logs the rank pairs in which a
 * mismatch was found.
 *
 *   sys : supplies the atom count (sys->atom.natom) and the current time
 *   f1  : forces to be checked (natom vectors)
 *   f2  : scratch buffer of the same size; overwritten on even ranks
 *
 * When the number of PEs is odd (including a single PE) the last even rank
 * has no partner: it skips the comparison, but still reports "no error" to
 * rank 0 so that the message counts on both sides match.
 */
void MPI_force_test(MD_SYSTEM *sys, VEC *f1, VEC *f2)
{
  MPI_Status stat;
  int i, error = 0;

  if (mpi.rank % 2 == 1) {
    MPI_Send(f1, sys->atom.natom*3, MPI_DOUBLE, mpi.rank-1, 10, mpi.comm);
    return;
  }

  /* Even rank: compare with the partner only if rank + 1 really exists. */
  if (mpi.rank + 1 < mpi.n_pe) {
    MPI_Recv(f2, sys->atom.natom*3, MPI_DOUBLE, mpi.rank+1, 10, mpi.comm,
             &stat);
    for (i=0;i<sys->atom.natom;i++) {
      if (f1[i].x != f2[i].x ||
          f1[i].y != f2[i].y ||
          f1[i].z != f2[i].z) {
        error = 1;
        printf("%8.3f ps %5d(%.16e,%.16e,%.16e)\n         != %5d(%.16e,%.16e,%.16e)\n",
               sys->current_time,
               mpi.rank,    f1[i].x,f1[i].y,f1[i].z,
               i,           f2[i].x,f2[i].y,f2[i].z);
      }
    }
  }

  if (mpi.rank == 0) {
    /* i == 0 uses rank 0's own flag; the other even ranks send theirs. */
    for (i=0;i<mpi.n_pe;i+=2) {
      if (i > 0)
        MPI_Recv(&error, 1, MPI_INT, i, 12, mpi.comm, &stat);
      if (error) {
        lprintf("ERROR OCCURED AT RANK %d or %d\n", i, i+1);
        lflush();
      }
    }
  } else {
    MPI_Send(&error, 1, MPI_INT, 0, 12, mpi.comm);
  }
}
		    
/*
 * Gather the atom coordinates of all PEs so that every PE ends up with the
 * complete, up-to-date sys->atom.x array.
 *
 * Each PE contributes the slice it owns, as described by
 * ad->data_counts[] / ad->data_offsets[] (set up in assign_tasks_atom()).
 * The slice already lives at its final position inside ad->x, so the
 * gather is done in place: the MPI standard does not allow the send
 * buffer to overlap the receive buffer, which passing
 * &ad->x[ad->start_task] as the send buffer would do.  With MPI_IN_PLACE
 * the send count and send type arguments are ignored.
 *
 * Velocities (ad->v) are not gathered here.
 * The elapsed time is added to sys->time_comm4.
 */
void X_all_gather(MD_SYSTEM *sys)
{
  ATOM_DATA *ad;

  ad=&sys->atom;
  dtime();

  MPI_Allgatherv(MPI_IN_PLACE, ad->data_counts[mpi.rank], MPI_DOUBLE,
                 ad->x, ad->data_counts, ad->data_offsets, MPI_DOUBLE,
                 mpi.comm);

  add_dtime(&sys->time_comm4);
}

#endif  /* MPI_RDMD */

#ifdef MPI

/*
 * Broadcast the dynamic state held by the master PE (mpi.master_pe) to
 * all PEs:
 *   - atom coordinates and velocities (natom*3 doubles each),
 *   - the rigid-molecule buffer rd->crd (13 doubles per molecule), which is
 *     packed before and unpacked after the broadcast, followed by
 *     RMOL_DATA_mol_to_room_all(),
 *   - sys->eta and sys->eta_v if sys->Ex_System_T_flag is set,
 *   - sys->logv_v if sys->Ex_System_P_flag == 1, or sys->Vg (9 doubles)
 *     if sys->Ex_System_P_flag >= 2.
 */
void sync_xv(MD_SYSTEM *sys)
{
  ATOM_DATA *ad = &sys->atom;
  RMOL_DATA *rd = &sys->rigid_mol;
  const int n_xyz = ad->natom*3;
  const int n_eta = MAX_EX_SYSTEM*MAX_CHAIN_T;

  /* atoms */
  MPI_Bcast(ad->x, n_xyz, MPI_DOUBLE, mpi.master_pe, mpi.comm);
  MPI_Bcast(ad->v, n_xyz, MPI_DOUBLE, mpi.master_pe, mpi.comm);

  /* rigid molecules: pack, broadcast, unpack */
  RMOL_DATA_copy_rmolcrd_to_buf(rd);
  MPI_Bcast(rd->crd, rd->n_mol*13, MPI_DOUBLE, mpi.master_pe, mpi.comm);
  RMOL_DATA_copy_rmolcrd_from_buf(rd);
  RMOL_DATA_mol_to_room_all(rd, ad);

  /* extended-system variables for temperature control */
  if (sys->Ex_System_T_flag) {
    MPI_Bcast(sys->eta,   n_eta, MPI_DOUBLE, mpi.master_pe, mpi.comm);
    MPI_Bcast(sys->eta_v, n_eta, MPI_DOUBLE, mpi.master_pe, mpi.comm);
  }

  /* extended-system variables for pressure control */
  if (sys->Ex_System_P_flag >= 2) {
    MPI_Bcast(sys->Vg, 9, MPI_DOUBLE, mpi.master_pe, mpi.comm);
  } else if (sys->Ex_System_P_flag == 1) {
    MPI_Bcast(&sys->logv_v, 1, MPI_DOUBLE, mpi.master_pe, mpi.comm);
  }
}

/*
 * Parallel log printf: every PE formats its own message and rank 0 writes
 * all of them, in rank order, through lprintf() with a "rank[NN]:" prefix.
 *
 * Must be called by all PEs, since rank 0 waits for one message from each
 * other rank.
 *
 *   fmt, ... : printf-style format and arguments.  The formatted message
 *              is truncated to fit the fixed-size buffer (999 characters
 *              plus the terminating NUL).
 */
void plprintf(char *fmt, ...)
{
  int i;
  char buf[1000] = "";   /* zero-filled: the whole buffer is sent below */
  MPI_Status status;
  va_list args;

  va_start(args, fmt);
  /* bounded formatting: a long message must not overrun buf */
  vsnprintf(buf, sizeof(buf), fmt, args);
  va_end(args);

  if (mpi.rank == 0) {
    lprintf("rank[%2d]:%s",mpi.rank,buf);
    for (i=0;i<mpi.n_pe;i++) {
      if (i!=mpi.rank) {
        MPI_Recv(buf, 1000, MPI_CHAR, i, TAG_PLPRINTF, mpi.comm, &status);
        lprintf("rank[%2d]:%s",i,buf);
      }
    }
  } else {
    MPI_Send(buf, 1000, MPI_CHAR, 0, TAG_PLPRINTF, mpi.comm);
  }
}

/*
 * Node of the singly linked list of reusable buffers that get_buf() hands
 * out and free_buf() marks as available again.
 */
typedef struct s_tr_buf {
  void *data;  /* the allocated memory block returned to callers */
  size_t size;  /* size of data in bytes */
  int in_use;  /* 1 while handed out by get_buf(), 0 after free_buf() */
  struct s_tr_buf *next;  /* next node in the list, NULL at the end */
} buf_t;

/* Head of the buffer list; get_buf() inserts new nodes at the front. */
static buf_t *_buf_head = NULL;

/*
 * Borrow a scratch buffer of at least size bytes.
 *
 * The first buffer in the list that is not in use and is large enough is
 * reused; otherwise a new one is allocated with emalloc(), added to the
 * front of the list and the allocation is logged.  The buffer stays
 * reserved until it is passed to free_buf().
 *
 * Returns NULL when size is 0.  The contents of the returned buffer are
 * not initialised here.
 */
void *get_buf(size_t size)
{
  buf_t *buf;
  char *func="get_buf";   /* label handed to emalloc() */

  if (size == 0) return NULL;
  for (buf=_buf_head;buf!=NULL;buf=buf->next) {
    if (!buf->in_use && size <= buf->size) {
      buf->in_use = 1;
      return buf->data;
    }
  }
  buf=emalloc(func,sizeof(buf_t));
  buf->next=_buf_head;
  _buf_head=buf;
  buf->data=emalloc(func,size);
  buf->size=size;
  buf->in_use = 1;
  /* size is a size_t: convert explicitly so that it matches the format */
  lprintf("BUF: Allocate %lu.\n",(unsigned long) size);
  return buf->data;
}

/*
 * Borrow a scratch buffer with room for size doubles; see get_buf().
 * Release it with free_buf().
 */
double *get_double_buf(size_t size)
{
  const size_t n_bytes = sizeof(double)*size;

  return get_buf(n_bytes);
}

/*
 * Borrow a scratch buffer with room for size ints; see get_buf().
 * Release it with free_buf().
 */
int *get_int_buf(size_t size)
{
  const size_t n_bytes = sizeof(int)*size;

  return get_buf(n_bytes);
}

/*
 * Give a buffer obtained from get_buf() back to the pool.
 *
 * The memory is not released; the node is only marked as not in use so
 * that a later get_buf() can reuse it.  A NULL pointer is ignored.  A
 * pointer that does not belong to any node is reported on stderr and
 * marble_exit(1) is called.
 */
void free_buf(void *p)
{
  buf_t *node;

  if (p==NULL) return;

  node = _buf_head;
  while (node != NULL && node->data != p) {
    node = node->next;
  }

  if (node == NULL) {
    fprintf(stderr,"INTERNAL ERROR: Can't find the buffer for free.\n");
    marble_exit(1);
    return;
  }

  node->in_use = 0;
}

/*
 * Collect statistics of val over all PEs.  On return, on every PE:
 *   ave_min_max[0] = average (sum over PEs divided by mpi.n_pe)
 *   ave_min_max[1] = minimum
 *   ave_min_max[2] = maximum
 */
void ave_min_max(double val, double ave_min_max[3])
{
  double *stats = ave_min_max;

  /* average */
  MPI_Allreduce(&val, stats, 1, MPI_DOUBLE, MPI_SUM, mpi.comm);
  stats[0] /= mpi.n_pe;

  /* extremes */
  MPI_Allreduce(&val, stats + 1, 1, MPI_DOUBLE, MPI_MIN, mpi.comm);
  MPI_Allreduce(&val, stats + 2, 1, MPI_DOUBLE, MPI_MAX, mpi.comm);
}

/*
 * Open a file on the master PE only and tell all PEs whether it worked.
 *
 * Returns the FILE pointer on the master when fopen() succeeded.  Returns
 * NULL on every PE when fopen() failed on the master.  Note that the
 * non-master PEs also get NULL on success, because they never open the
 * file.
 */
FILE *par_fopen(char *fname, char *mode)
{
  FILE *fp = NULL;
  int failed = 0;

  if (mpi.master) {
    fp = fopen(fname, mode);
    failed = (fp == NULL);
  }

  /* share the master's result so that all PEs can take the same branch */
  MPI_Bcast(&failed, 1, MPI_INT, mpi.master_pe, mpi.comm);

  return failed ? NULL : fp;
}

/*
 * Read one line on the master PE with fgets() and broadcast it to all PEs.
 *
 * All size bytes of buf are broadcast.  When fgets() returns NULL on the
 * master, buf[0] is set to (char) -1 as a marker before the broadcast, and
 * every PE returns NULL; otherwise every PE returns buf.
 */
char *par_fgets(char *buf, int size, FILE *fp)
{
  if (mpi.master && fgets(buf, size, fp) == NULL) {
    /* in-band "no more data" marker for the other PEs */
    buf[0] = (char) -1;
  }

  MPI_Bcast(buf, size, MPI_CHAR, mpi.master_pe, mpi.comm);

  return (buf[0] == (char) -1) ? NULL : buf;
}

/*
 * Close fp on the master PE and broadcast the return value of fclose()
 * so that every PE returns the same result.
 */
int par_fclose(FILE *fp)
{
  int status = 0;   /* overwritten by the broadcast on non-master PEs */

  if (mpi.master)
    status = fclose(fp);

  MPI_Bcast(&status, 1, MPI_INT, mpi.master_pe, mpi.comm);
  return status;
}

#else
static int dummy;   /* placeholder declaration for builds without MPI; suppresses a compiler warning */
#endif  /* ifdef MPI */
