#include "parblock.h"

//#define DEBUG

#ifdef DEBUG
// Debug helper: print the size x size row-major matrix `a` to stdout,
// one matrix row per output line, each element rounded to zero decimals
// and followed by a single space.
void print_mat(double* a, int size)
{
  for(int row = 0; row < size; ++row){
    // Walk the current row with a pointer instead of recomputing the index.
    const double* cell = a + row*size;
    const double* row_end = cell + size;
    while(cell != row_end){
      printf("%.0f ", *cell);
      ++cell;
    }
    printf("\n");
  }
}
#endif

// Compute rows [row_begin, row_end) of c = a * b, where a, b and c are
// n x n matrices stored in row-major order.  All index arithmetic is done
// in size_t so that i*n+j cannot overflow an int for large n.
static void spmd_matmul_rows(const double* a, const double* b, double* c,
                             size_t n, size_t row_begin, size_t row_end)
{
  for(size_t i = row_begin; i < row_end; i++){
    for(size_t j = 0; j < n; j++){
      double sum = 0;
      for(size_t k = 0; k < n; k++){
        sum += a[i*n+k] * b[k*n+j];
      }
      c[i*n+j] = sum;
    }
  }
}

// Wall-clock seconds elapsed between two gettimeofday() samples.
static double spmd_elapsed_sec(const struct timeval* from,
                               const struct timeval* to)
{
  return to->tv_sec - from->tv_sec + (to->tv_usec - from->tv_usec)/1000000.0;
}

// Benchmark driver: multiplies two random size x size matrices, first
// serially and then with the rows split across the PARDS processes, and
// prints the elapsed wall-clock time of each run.
//
// Usage: spmdmatmul size
// Returns 0 on success (or when size < 2, in which case nothing is done),
// 1 when the size argument is not a number or is too large.
int main(int argc, char* argv[])
{
  if(argc < 2) {
    printf("spmdmatmul size\n");
    exit(0);
  }

  // strtol (unlike atoi) lets us tell "not a number" apart from 0 and has
  // defined behaviour when the value is out of range.
  char* parse_end = NULL;
  long requested = strtol(argv[1], &parse_end, 10);
  if(parse_end == argv[1] || *parse_end != '\0') {
    fprintf(stderr, "spmdmatmul: invalid size '%s'\n", argv[1]);
    return 1;
  }
  if(requested < 2) return 0;

  // Reject sizes whose byte count (three matrices plus the extra pool
  // bytes) would not fit in size_t, or that do not fit in an int.
  const size_t pool_slack = 100000;  // extra bytes requested on top of the matrices
  const size_t max_elems = ((size_t)-1 - pool_slack) / (3 * sizeof(double));
  const size_t n = (size_t)requested;
  int size = (int)requested;
  if((long)size != requested || n > max_elems / n) {
    fprintf(stderr, "spmdmatmul: size %ld is too large\n", requested);
    return 1;
  }
  const size_t matrix_bytes = n * n * sizeof(double);

  // NOTE(review): the meaning of the first argument (16) is defined by
  // PARDS and is not visible here; kept as in the original.
  pards_init(16, 3*matrix_bytes + pool_slack);

  double *a = (double*)pards_shmalloc(matrix_bytes);
  double *b = (double*)pards_shmalloc(matrix_bytes);
  double *c = (double*)pards_shmalloc(matrix_bytes);

  // initialize: fixed seed so every run multiplies the same matrices
  srand(1);
  for(size_t idx = 0; idx < n*n; idx++){
    a[idx] = rand() % 10;
    b[idx] = rand() % 10;
  }

  struct timeval t1, t2;

  printf("serial version\n");
  // The timezone argument is obsolete; POSIX leaves a non-null value
  // unspecified, so pass NULL.
  gettimeofday(&t1,NULL);
  spmd_matmul_rows(a, b, c, n, 0, n);
  gettimeofday(&t2,NULL);

#ifdef DEBUG
  printf("a = \n"); print_mat(a,size);
  printf("b = \n"); print_mat(b,size);
  printf("c = \n"); print_mat(c,size);
#endif

  printf("elapsed time = %f sec\n", spmd_elapsed_sec(&t1, &t2));

  printf("parallel version\n");
  gettimeofday(&t1,NULL);

  // NOTE(review): presumably every PARDS process executes the code between
  // pards_begin_parallel and pards_end_parallel, each with its own pno --
  // confirm against the PARDS documentation.
  int nproc = pards_get_nprocs();
  PBInfo* pbi = pards_begin_parallel(nproc);
  int pno = pbi->getpno();

  // Split the rows into nproc contiguous chunks of ceil(size/nproc) rows;
  // the last chunk is clipped to the matrix, and chunks that start past
  // the end are empty.
  int part_size = 1 + (size - 1)/ nproc;
  int start = part_size * pno;
  int end = part_size * (pno+1);
  if(end > size) end = size;

  if(start < end)
    spmd_matmul_rows(a, b, c, n, (size_t)start, (size_t)end);

  pards_end_parallel(pbi);
  gettimeofday(&t2,NULL);

#ifdef DEBUG
  printf("a = \n"); print_mat(a,size);
  printf("b = \n"); print_mat(b,size);
  printf("c = \n"); print_mat(c,size);
#endif

  printf("elapsed time = %f sec\n", spmd_elapsed_sec(&t1, &t2));

  pards_finalize();
  return 0;
}
