/*
 * Copyright (c) 2003-2005 RIKEN Japan, All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY RIKEN AND CONTRIBUTORS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL RIKEN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

/* $SATELLITE: satellite4/modules/bps/command/learn.cpp,v 1.5 2005/02/22 04:32:19 ninja Exp $ */

#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#include "libbps.h"

/************************************************
 *                                              *
 *  Back Propagation Simulator (BPS)            *
 *        subroutine package                    *
 *        Version 4.0                           *
 *    coded           in May.17 1989            *
 *    coded by        Y.Okamura                 *
 *    last modified   in Nov.15 1990            *
 *    modified by     K.Kuroda                  *
 *                                              *
 ************************************************
 *                                              *
 *  filename setlmain.c                         *
 *      BP learning controller                  *
 *                                              *
 ************************************************/

#ifdef __cplusplus
extern "C" {
#endif

/* #include "learn.h" */

/*
 * CHECKEND: loop condition used by every training loop in this file.
 * True while the current error (bps_cont.SumOfErr) is still above
 * bps_cont.MinError AND learn_cnt has not reached bps_cont.MaxLearnCount.
 * NOTE(review): the count test uses '!=' rather than '<', so a
 * MaxLearnCount that learn_cnt never equals (e.g. a negative value) leaves
 * only the error threshold to stop training -- confirm this is intended.
 */
#define	CHECKEND  ((bps_cont.SumOfErr>bps_cont.MinError) && (learn_cnt!=bps_cont.MaxLearnCount))

/*
 * Learning-run state shared by the routines below.
 *   learn_cnt       - iterations completed; incremented at the top of each
 *                     training-loop pass and tested by CHECKEND.
 *   wgt_stor_cnt    - weight-history position; store_weight() writes slot
 *                     wgt_stor_cnt - 1 and advances it in APPEND mode.
 *   err_stor_cnt    - error-history record passed to StoreErrRecord() /
 *                     WriteErrDataPoint(); advanced by store_error() in
 *                     APPEND mode.
 *   err_datapoint   - next data-point index for WriteErrDataPoint()
 *                     (non-record storage direction only).
 *   display_cnt     - number of errors recorded in bps_cont.buff_err.
 *   buff_no         - buffer number from GetBufferID(0); 0 means none.
 *   buff_write_flag - nonzero when buff_no != 0, i.e. bps_cont.buff_err is
 *                     also written to that buffer.
 */
static int     learn_cnt, wgt_stor_cnt;
static int     err_stor_cnt, err_datapoint, display_cnt;
static int     buff_no, buff_write_flag;


/************************************************
  set_learn -- one pass over the whole pattern set

  Zeroes bps_cont.OutCellErr for every cell of the last layer and calls
  workspace_initialize() once.  It then runs forward_learn() on each
  (InputData, TeachData) pair and sums the values it returns.  When
  lrn_mode is nonzero, backward_learn() is also called after each forward
  pass.  Because the workspace is initialised only once here, whatever
  backward_learn() accumulates spans the whole set (presumably batch
  gradients -- confirm in libbps).

  input:
    lrn_mode : nonzero -> also run backward_learn() per pattern
  return:
    sum of the forward_learn() results over all patterns
  ************************************************/
double set_learn(int lrn_mode)
{
  double  total = 0.0;
  int     k, p;

  /* Clear the output-layer error vector. */
  for (k = 0; k < bps_cont.NumOfCell[bps_cont.NumOfLayer-1]; k++)
    bps_cont.OutCellErr[k] = 0.0;

  workspace_initialize();

  for (p = 0; p < bps_cont.NumOfPtrn; p++) {
    total += forward_learn(bps_cont.InputData[p], bps_cont.TeachData[p]);
    if (lrn_mode != 0)
      backward_learn();
  }

  return total;
}


/*************************************************
  store_weight -- append the current weights to the weight-history file

  Writes the weights with StoreWeight2() at slot wgt_stor_cnt - 1, and
  advances wgt_stor_cnt when the file is in APPEND mode.  It then rewrites
  the file header so that its comment records the CoefLearn of the link
  between BPNet[0][1] and BPNet[1][1] ("learning rate = ...").

  return:
    0   success
    102 LoadHeader() failed (Header Read Error)
    120 inter_ilin() found no such link (Can't Find Link)
  NOTE(review): the results of StoreWeight2() and StoreHeader() are not
  checked, and head.comment is filled with an unbounded sprintf().
  *************************************************/
int
store_weight()
{
  Header       hdr;
  bps_ilin_t  *lnk;

  /* T.Hayasaka 1994 4/26: slot is wgt_stor_cnt - 1 (was wgt_stor_cnt). */
  StoreWeight2(bps_cont.WgtHistoryFile, wgt_stor_cnt - 1);
  if (bps_cont.WgtStorMode == BPS_STOREMODE_APPEND)
    ++wgt_stor_cnt;

  if (LoadHeader(bps_cont.WgtHistoryFile, &hdr) == -1)
    return 102; /* Header Read Error */

  lnk = inter_ilin(bps_cont.BPNet[0][1].CellNode, bps_cont.BPNet[1][1].CellNode);
  if (lnk == NULL)
    return 120; /* Can't Find Link */

  sprintf(hdr.comment, "%s%f", "learning rate = ", lnk->CoefLearn);
  StoreHeader(bps_cont.WgtHistoryFile, &hdr);

  return 0;
}


/*************************************************
  store_error -- record bps_cont.SumOfErr in the error-history file

  Record direction  : StoreErrRecord() at record err_stor_cnt.
  Other directions  : WriteErrDataPoint() at (err_stor_cnt, err_datapoint)
                      with the current error, then advance err_datapoint.
  In APPEND mode err_stor_cnt is advanced afterwards.
  *************************************************/
void
store_error()
{
  if (bps_cont.ErrStorDirection != BPS_STOREDIR_RECORD) {
    WriteErrDataPoint(bps_cont.ErrHistoryFile, err_stor_cnt, err_datapoint, bps_cont.SumOfErr);
    err_datapoint++;
  } else {
    StoreErrRecord(bps_cont.ErrHistoryFile, err_stor_cnt);
  }

  if (bps_cont.ErrStorMode == BPS_STOREMODE_APPEND)
    err_stor_cnt += 1;
}


/*************************************************
  display -- print the progress report for the current iteration

  Prints learn_cnt, the current error bps_cont.SumOfErr, its change from
  the previously recorded error, and bps_cont.Comment.  When a display
  buffer is attached (buff_write_flag), bps_cont.buff_err is also passed
  to WriteBuffer() with a single dimension of display_cnt (presumably the
  first display_cnt entries -- confirm in libbps).

  The training loops append SumOfErr to buff_err (display_cnt++) before
  calling this, so buff_err[display_cnt-2] is the previous error.

  return: 0 on success, 3 if WriteBuffer() fails (Buffer Write Error)
  *************************************************/
int
display()
{
  int     dims[10];
  double  prev_diff;

  /* With fewer than two recorded errors there is no previous value, and
     buff_err[display_cnt-2] would index before the start of the buffer.
     In that case report the error itself, as the iteration-0 report in
     mod_bps_learn() does. */
  if (display_cnt >= 2)
    prev_diff = bps_cont.SumOfErr - bps_cont.buff_err[display_cnt-2];
  else
    prev_diff = bps_cont.SumOfErr;

  printf("\n");
  printf("### ITERATION        ### = %d\n",   learn_cnt);
  printf("### SQUARE'S ERROR   ### = %14e\n", bps_cont.SumOfErr);
  printf("### DIFFERENCE       ### = %14e\n", prev_diff);
  printf("### COMMENT          ### = %s\n",   bps_cont.Comment);

  if (buff_write_flag) {
    dims[0] = display_cnt;
    if (WriteBuffer(buff_no, 1, dims, bps_cont.buff_err) == -1)
      return 3; /* Buffer Write Error */
  }

  return 0;
}


/************************************************
  Steep_method -- training loop using the steep() weight update

  Runs while CHECKEND holds.  Each iteration:
    set mode (LearnMode == BPS_LEARNMODE_SET): one steep() update, then the
      error of the updated weights is recomputed with set_learn(1);
    pattern mode: for every pattern, clear OutCellErr, call
      workspace_initialize(), forward_learn() + backward_learn(), then
      steep(); the iteration's error is the sum of the per-pattern
      forward_learn() results.
  The weights, error-history file, error buffer and console report are
  then updated at WgtStorInterval / ErrStorInterval / every iteration /
  DisplayInterval respectively.
  return: always 0
  ************************************************/
int
Steep_method()
{
  int     p, k;
  double  epoch_err = 0.0;

  while (CHECKEND) {
    ++learn_cnt;

    if (bps_cont.LearnMode != BPS_LEARNMODE_SET) {
      epoch_err = 0.0;
      for (p = 0; p < bps_cont.NumOfPtrn; ++p) {
        for (k = 0; k < bps_cont.NumOfCell[bps_cont.NumOfLayer-1]; ++k)
          bps_cont.OutCellErr[k] = 0.0;

        workspace_initialize();
        epoch_err += forward_learn(bps_cont.InputData[p], bps_cont.TeachData[p]);
        backward_learn();
        steep();
      }
    } else {
      steep();
    }

    if (learn_cnt % bps_cont.WgtStorInterval == 0)
      store_weight();

    if (bps_cont.LearnMode == BPS_LEARNMODE_SET)
      bps_cont.SumOfErr = set_learn(1);
    else
      bps_cont.SumOfErr = epoch_err;

    if (learn_cnt % bps_cont.ErrStorInterval == 0)
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    display_cnt++;
    if (learn_cnt % bps_cont.DisplayInterval == 0)
      display();
  }

  return 0;
}


/************************************************
  setMomentum -- momentum training loop, set (whole-pattern-set) mode

  Runs while CHECKEND holds.  Each iteration applies momentum1() once,
  stores the weights every WgtStorInterval iterations, recomputes the
  error over all patterns with set_learn(1) (backward passes enabled),
  then stores / buffers / displays that error at the configured intervals.
  return: always 0
  ************************************************/
int
setMomentum()
{
  while (CHECKEND) {
    learn_cnt += 1;

    momentum1();
    if (!(learn_cnt % bps_cont.WgtStorInterval))
      store_weight();

    bps_cont.SumOfErr = set_learn(1);
    if (!(learn_cnt % bps_cont.ErrStorInterval))
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    display_cnt += 1;
    if (!(learn_cnt % bps_cont.DisplayInterval))
      display();
  }

  return 0;
}


/************************************************
  patternMomentum -- momentum training loop, pattern-by-pattern mode

  Runs while CHECKEND holds.  Each iteration visits every pattern: it
  clears OutCellErr, calls workspace_initialize(), then forward_learn()
  (adding its result to bps_cont.SumOfErr), backward_learn() and
  momentum1().  SumOfErr therefore ends up as the sum of the per-pattern
  errors measured before each update.  Weights and error are then stored,
  buffered and displayed at the configured intervals.
  return: always 0
  ************************************************/
int
patternMomentum()
{
  int  p, k;

  while (CHECKEND) {
    learn_cnt += 1;

    bps_cont.SumOfErr = 0.0;
    for (p = 0; p < bps_cont.NumOfPtrn; ++p) {
      for (k = 0; k < bps_cont.NumOfCell[bps_cont.NumOfLayer-1]; ++k)
        bps_cont.OutCellErr[k] = 0.0;

      workspace_initialize();
      bps_cont.SumOfErr += forward_learn(bps_cont.InputData[p], bps_cont.TeachData[p]);
      backward_learn();
      momentum1();
    }

    if (!(learn_cnt % bps_cont.WgtStorInterval))
      store_weight();
    if (!(learn_cnt % bps_cont.ErrStorInterval))
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    ++display_cnt;
    if (!(learn_cnt % bps_cont.DisplayInterval))
      display();
  }

  return 0;
}


/************************************************
  Vogl_method -- training loop using the vogl() weight update

  Before training, every link reachable through Getintoplist() /
  Getinfwdlist() from each cell BPNet[layer][1..NumOfCell[layer]] has its
  CoefLearn reset to bps_cont.LearnRate.  The loop then has the same shape
  as Steep_method():
    set mode: one vogl() update, then the error is recomputed with
      set_learn(1);
    pattern mode: vogl() after each pattern's forward/backward pass, with
      the iteration's error being the sum of forward_learn() results.
  return: always 0
  ************************************************/
int
Vogl_method()
{
  int          p, k, layer, cell;
  double       epoch_err = 0.0;
  bps_ilin_t  *lnk;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++)
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++)
      for (lnk = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           lnk != NULL;
           lnk = Getinfwdlist(lnk))
        lnk->CoefLearn = bps_cont.LearnRate;

  while (CHECKEND) {
    ++learn_cnt;

    if (bps_cont.LearnMode != BPS_LEARNMODE_SET) {
      epoch_err = 0.0;
      for (p = 0; p < bps_cont.NumOfPtrn; ++p) {
        for (k = 0; k < bps_cont.NumOfCell[bps_cont.NumOfLayer-1]; ++k)
          bps_cont.OutCellErr[k] = 0.0;

        workspace_initialize();
        epoch_err += forward_learn(bps_cont.InputData[p], bps_cont.TeachData[p]);
        backward_learn();
        vogl();
      }
    } else {
      vogl();
    }

    if (learn_cnt % bps_cont.WgtStorInterval == 0)
      store_weight();

    if (bps_cont.LearnMode == BPS_LEARNMODE_SET)
      bps_cont.SumOfErr = set_learn(1);
    else
      bps_cont.SumOfErr = epoch_err;

    if (learn_cnt % bps_cont.ErrStorInterval == 0)
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    display_cnt++;
    if (learn_cnt % bps_cont.DisplayInterval == 0)
      display();
  }

  return 0;
}


/************************************************
  Jacob_method -- training loop using the jacobs() weight update

  Runs while CHECKEND holds.  In set mode jacobs() is applied once per
  iteration and the error is recomputed with set_learn(1).  In pattern
  mode jacobs() follows each pattern's forward_learn()/backward_learn(),
  and the iteration's error is the sum of the forward_learn() results.
  Weights and error are then stored, buffered and displayed at the
  configured intervals.
  return: always 0
  ************************************************/
int
Jacob_method()
{
  int     pat, cell;
  double  sweep_err = 0.0;

  while (CHECKEND) {
    learn_cnt += 1;

    if (bps_cont.LearnMode == BPS_LEARNMODE_SET) {
      jacobs();
    } else {
      sweep_err = 0.0;
      for (pat = 0; pat < bps_cont.NumOfPtrn; pat++) {
        for (cell = 0; cell < bps_cont.NumOfCell[bps_cont.NumOfLayer-1]; cell++)
          bps_cont.OutCellErr[cell] = 0.0;

        workspace_initialize();
        sweep_err += forward_learn(bps_cont.InputData[pat], bps_cont.TeachData[pat]);
        backward_learn();
        jacobs();
      }
    }

    if (!(learn_cnt % bps_cont.WgtStorInterval))
      store_weight();

    if (bps_cont.LearnMode != BPS_LEARNMODE_SET)
      bps_cont.SumOfErr = sweep_err;
    else
      bps_cont.SumOfErr = set_learn(1);

    if (!(learn_cnt % bps_cont.ErrStorInterval))
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    display_cnt += 1;
    if (!(learn_cnt % bps_cont.DisplayInterval))
      display();
  }

  return 0;
}


/************************************************
  Vogl_coef_method -- momentum training with Vogl's coefficients

  Calls vgl2_coe() once before training, then loops while CHECKEND holds
  using momentum2() as the weight update:
    set mode: one momentum2() per iteration, error recomputed with
      set_learn(1);
    pattern mode: momentum2() after each pattern's forward/backward pass,
      with the iteration's error being the sum of forward_learn() results.
  Weights and error are then stored, buffered and displayed at the
  configured intervals.
  return: always 0
  ************************************************/
int
Vogl_coef_method()
{
  int     pat, cell;
  double  sweep_err = 0.0;

  vgl2_coe();

  while (CHECKEND) {
    learn_cnt += 1;

    if (bps_cont.LearnMode == BPS_LEARNMODE_SET) {
      momentum2();
    } else {
      sweep_err = 0.0;
      for (pat = 0; pat < bps_cont.NumOfPtrn; pat++) {
        for (cell = 0; cell < bps_cont.NumOfCell[bps_cont.NumOfLayer-1]; cell++)
          bps_cont.OutCellErr[cell] = 0.0;

        workspace_initialize();
        sweep_err += forward_learn(bps_cont.InputData[pat], bps_cont.TeachData[pat]);
        backward_learn();
        momentum2();
      }
    }

    if (!(learn_cnt % bps_cont.WgtStorInterval))
      store_weight();

    if (bps_cont.LearnMode != BPS_LEARNMODE_SET)
      bps_cont.SumOfErr = sweep_err;
    else
      bps_cont.SumOfErr = set_learn(1);

    if (!(learn_cnt % bps_cont.ErrStorInterval))
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    display_cnt += 1;
    if (!(learn_cnt % bps_cont.DisplayInterval))
      display();
  }

  return 0;
}


/************************************************
  Ochi_method -- "an accelerated learning method to reduce the
  oscillation of weight for neural networks"

  Before training, every link reachable through Getintoplist() /
  Getinfwdlist() from each cell BPNet[layer][1..NumOfCell[layer]] gets
  CoefLearn = bps_cont.LearnRate and wgtworkold = -WgtWork.  Each
  iteration then calls Ochi(), stores the weights every WgtStorInterval
  iterations, and recomputes the error with set_learn(1).  That error is
  then stored, buffered and displayed at the configured intervals.
  bps_cont.LearnMode is not consulted here.
  return: always 0
  ************************************************/
int
Ochi_method()
{
  int          layer, cell;
  bps_ilin_t  *lnk;

  for (layer = 0; layer < bps_cont.NumOfLayer; layer++) {
    for (cell = 1; cell <= bps_cont.NumOfCell[layer]; cell++) {
      for (lnk = Getintoplist(bps_cont.BPNet[layer][cell].CellNode);
           lnk != NULL;
           lnk = Getinfwdlist(lnk)) {
        lnk->CoefLearn  = bps_cont.LearnRate;
        lnk->wgtworkold = -lnk->WgtWork;
      }
    }
  }

  while (CHECKEND) {
    learn_cnt += 1;
    Ochi();

    if (learn_cnt % bps_cont.WgtStorInterval == 0)
      store_weight();

    bps_cont.SumOfErr = set_learn(1);

    if (learn_cnt % bps_cont.ErrStorInterval == 0)
      store_error();

    bps_cont.buff_err[display_cnt] = (double)bps_cont.SumOfErr;
    display_cnt += 1;

    if (learn_cnt % bps_cont.DisplayInterval == 0)
      display();
  }

  return 0;
}


/************************************************
  init_err_history_counters -- pick where error history is written

  Sets err_stor_cnt (record passed to StoreErrRecord() /
  WriteErrDataPoint()) and err_datapoint (data-point index, used only in
  the non-record direction) for a new learning run.  Both are file-static
  and survive between runs, so every branch assigns both explicitly.

  APPEND, record direction : continue at NextErrHistory(); a result <= 0
                             is treated as an empty file -> record 1.
  APPEND, other direction  : use record GetNumOfRecord() + 1 and take the
                             next data-point index from StoreOldError();
                             if that is 0, start at record 1.
  any other mode           : record 1, data point 0.
  ************************************************/
static void
init_err_history_counters(void)
{
  if (bps_cont.ErrStorMode == BPS_STOREMODE_APPEND) {
    if (bps_cont.ErrStorDirection == BPS_STOREDIR_RECORD) {
      /* store_error() never touches err_datapoint in this direction; clear
         any value left by a previous run so the initial-error check in
         mod_bps_learn() behaves the same on every call. */
      err_datapoint = 0;
      err_stor_cnt  = NextErrHistory(bps_cont.ErrHistoryFile);
      /* The "err_datapoint == 0 -> record 1" rule of the data-point branch
         must not apply here: err_datapoint is always 0 in this branch, so
         it used to discard NextErrHistory() and overwrite the history
         from record 1 despite APPEND mode. */
      if (err_stor_cnt <= 0)
        err_stor_cnt = 1;
    } else {
      err_stor_cnt  = GetNumOfRecord(bps_cont.ErrHistoryFile) + 1;
      err_datapoint = StoreOldError(bps_cont.ErrHistoryFile, err_stor_cnt);
      if (err_datapoint == 0)
        err_stor_cnt = 1;
    }
  } else {
    err_datapoint = 0;
    err_stor_cnt  = 1;
  }
}


/************************************************
  learning main routine (module entry point)

  Reads and prints the structure and learning parameters, then builds the
  network and loads the initial weights from bps_cont.LrnInitWgtFile.  It
  prepares the weight- and error-history files and reports the error of
  the initial weights (iteration 0).  It then runs the training loop
  selected by bps_cont.LearnAlgo.  Afterwards it stores and displays the
  final error and weights, unless the last iteration already did so.

  return:
    0   success
    17  GetBufferID(0) returned a negative buffer number
    3   writing the initial error to the display buffer failed
  ************************************************/
DLLEXPORT int mod_bps_learn()
{
  int   idx[10];

  rebps();

  /* SYSTEM INITIALIZE */
  GetStructureParameters();
  PrintStructureParameters();
  GetLearningParameters();
  PrintLearningParameters();

  buff_no = (int)GetBufferID(0);
  if (buff_no < 0)
    return 17; /* Illegal Buffer No. */

  /* Buffer number 0 disables writing errors to a display buffer. */
  buff_write_flag = (buff_no != 0) ? 1 : 0;
  learn_cnt       = 0;
  display_cnt     = 0;

  /* NOTE(review): the original comment here was mis-encoded Japanese that
     mentioned ErrBuffer; presumably this allocates bps_cont.buff_err. */
  system_initialize();
  printf("system_initialize is OK\n");

  /* MAKE NETWORK */
  MakeNetwork();

  ReadWeight2(bps_cont.LrnInitWgtFile, BPS_LAST + 1);

  CreateFile3(bps_cont.WgtHistoryFile, "weight history", bps_cont.WgtStorMode);
  CreateFile3(bps_cont.ErrHistoryFile, "error history",  bps_cont.ErrStorMode);

  wgt_stor_cnt = NextWgtHistory(bps_cont.WgtHistoryFile);
  printf("wgt_stor_cnt   = %d\n", wgt_stor_cnt);
  if (wgt_stor_cnt == 0)   /* dora 1995/5/30: an empty history starts at 1 */
    wgt_stor_cnt = 1;

  init_err_history_counters();
  printf("ErrHistoryFile = %s\n", bps_cont.ErrHistoryFile);
  printf("err_stor_cnt   = %d\n", err_stor_cnt);

  /* LEARNING */

  /* Iteration 0: error of the initial weights. */
  bps_cont.SumOfErr = set_learn(1);
  /* Only stored when the error history starts at record 1, data point 0. */
  if ((err_stor_cnt  == 1) &&
      (err_datapoint == 0))
    store_error();

  bps_cont.buff_err[display_cnt++] = (double)bps_cont.SumOfErr;

  /* Printed inline rather than via display(): display() diffs against
     buff_err[display_cnt-2], which does not exist yet. */
  printf("\n");
  printf("### ITERATION        ### = %d\n",   learn_cnt);
  printf("### SQUARE'S ERROR   ### = %14e\n", bps_cont.SumOfErr);
  printf("### DIFFERENCE       ### = %14e\n", bps_cont.SumOfErr);
  printf("### COMMENT          ### = %s\n",   bps_cont.Comment);

  if (buff_write_flag) {
    idx[0] = display_cnt;
    if (WriteBuffer(buff_no, 1, idx, bps_cont.buff_err) == -1) {
      BreakNetwork();   /* release the network built by MakeNetwork() */
      return 3; /* Buffer Write Error */
    }
  }

  switch (bps_cont.LearnAlgo) {
  case BPS_LEARNALGO_STEEP:      Steep_method();     break;
  case BPS_LEARNALGO_MOMENTUM:
    if (bps_cont.LearnMode == BPS_LEARNMODE_SET)
      setMomentum();
    else
      patternMomentum();
    break;
  case BPS_LEARNALGO_VOGL:       Vogl_method();      break;
  case BPS_LEARNALGO_JACOB:      Jacob_method();     break;
  case BPS_LEARNALGO_MOMENTUM2:  Vogl_coef_method(); break;
  case BPS_LEARNALGO_OCHI:       Ochi_method();      break;
  default:
    /* Previously a silent no-op; make the skipped training visible. */
    printf("unknown learning algorithm (%d): no learning performed\n",
           (int)bps_cont.LearnAlgo);
    break;
  }

  /* Store/display the final state unless the last iteration already did
     (the loops do so whenever learn_cnt is a multiple of the interval). */
  if ((learn_cnt % bps_cont.ErrStorInterval) != 0) store_error();
  if ((learn_cnt % bps_cont.WgtStorInterval) != 0) store_weight();
  if ((learn_cnt % bps_cont.DisplayInterval) != 0) display();

  printf("\n\t*** Learning is done ! ***\n");

  BreakNetwork();
  //system_end();

  wrbps();
  return 0;
}

#ifdef __cplusplus
}
#endif