/*
 * Copyright (c) 2003-2005 RIKEN Japan, All rights reserved.
 *
 * Redistribution and use in source and binary forms, with or without
 * modification, are permitted provided that the following conditions
 * are met:
 *
 * 1. Redistributions of source code must retain the above copyright
 *    notice, this list of conditions and the following disclaimer.
 *
 * 2. Redistributions in binary form must reproduce the above copyright
 *    notice, this list of conditions and the following disclaimer in the
 *    documentation and/or other materials provided with the distribution.
 *
 * THIS SOFTWARE IS PROVIDED BY RIKEN AND CONTRIBUTORS ``AS IS'' AND ANY
 * EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
 * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
 * PURPOSE ARE DISCLAIMED.  IN NO EVENT SHALL RIKEN OR CONTRIBUTORS BE
 * LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR
 * CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF
 * SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS
 * INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN
 * CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE)
 * ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF
 * THE POSSIBILITY OF SUCH DAMAGE.
 */

/* $SATELLITE: satellite4/modules/bps/command/learnlib.cpp,v 1.5 2005/02/22 07:39:59 ninja Exp $ */

#ifdef HAVE_CONFIG_H
# include "config.h"
#endif
#include <stdio.h>
#include <stdlib.h>
#include <string.h>

#ifdef HAVE_UNISTD_H
# include <unistd.h>
#endif
#ifdef HAVE_IO_H
# include <io.h>
#endif

#include "libbps.h"

/*************************************************
 *                                               *
 *       Back Propagation Simulator (BPS)        *
 *             subroutine package                *
 *               Ver                             *
 *         coded         in Nov.16 1989          *
 *         modified by   K.Kuroda                *
 *                                               *
 *************************************************
 *                                               *
 *       filename learnlib.c                     *
 *                                               *
 *************************************************/

#ifdef __cplusplus
extern "C" {
#endif


/************************************************
  workspace initialize

  Resets the WgtWork accumulator to 0.0 on every input link of every
  cell (cells 1..NumOfCell[layer]) in every layer 0..NumOfLayer-1.
  backward_learn() adds into WgtWork, so this clears it between
  accumulation passes.
  ************************************************/
void workspace_initialize()
{
  int         layer_idx;
  int         cell_idx;
  bps_ilin_t *link;

  for (layer_idx = 0; layer_idx < bps_cont.NumOfLayer; layer_idx++) {
    /* BPNet cells are indexed starting at 1 */
    for (cell_idx = 1; cell_idx <= bps_cont.NumOfCell[layer_idx]; cell_idx++) {
      for (link = Getintoplist(bps_cont.BPNet[layer_idx][cell_idx].CellNode);
           link != NULL;
           link = Getinfwdlist(link)) {
        link->WgtWork = 0.0;
      }
    }
  }
}


/************************************************
  system initialize

  Prepares bps_cont for a learning run:
    - an ErrStorInterval / WgtStorInterval / DisplayInterval of 0 is
      replaced by 1,
    - SetNumOfLink() is called,
    - the input and teach files are checked for existence,
    - if LrnFirstPtrn or LrnLastPtrn is 0, the pattern range becomes
      1 .. GetNumOfRecord(InputFile),
    - InputData / TeachData are allocated and filled by ReadData(),
    - OutCellErr, ErrBuffer and buff_err are allocated and zero-filled.

  returns:
    0  : success
    35 : input file does not exist
    36 : teach file does not exist
    37 : can't allocate OutCellErr, ErrBuffer or buff_err
    41 : can't allocate InputData
    42 : can't allocate TeachData

  NOTE(review): on a non-zero return, buffers allocated earlier in this
  function are not released here.
  ************************************************/
int system_initialize()
{
  /* PARAMETER INITIALIZE: an interval of 0 is treated as 1 */
  if (bps_cont.ErrStorInterval == 0)  bps_cont.ErrStorInterval = 1;
  if (bps_cont.WgtStorInterval == 0)  bps_cont.WgtStorInterval = 1;
  if (bps_cont.DisplayInterval == 0)  bps_cont.DisplayInterval = 1;

  /* SET NumOfLink, WgtBlockSize AND  ErrBlockSize */
  SetNumOfLink();

  /* CHECK INPUT DATA AND TEACH DATA FILES.
     Done before the pattern-range default below, which hands InputFile
     to GetNumOfRecord(); a missing file is reported as 35 instead of
     being passed to that call. */
  if (access(bps_cont.InputFile, 0) == -1)
    return 35; /* Input File Isn't Exist */
  if (access(bps_cont.TeachFile, 0) == -1)
    return 36; /* Teach File Isn't Exist */

  /* SET PATTERN No.: 0 in either bound selects 1 .. GetNumOfRecord() */
  if ((bps_cont.LrnFirstPtrn == 0) || (bps_cont.LrnLastPtrn == 0)) {
    bps_cont.LrnFirstPtrn = 1;
    bps_cont.LrnLastPtrn  = GetNumOfRecord(bps_cont.InputFile);
  }

  bps_cont.NumOfPtrn = bps_cont.LrnLastPtrn - bps_cont.LrnFirstPtrn + 1;

  /* ALLOCATE INPUT DATA (NumOfPtrn x input-layer cells) */
  bps_cont.InputData = (float**)bps_malloc2D(bps_cont.NumOfPtrn,
                                             bps_cont.NumOfCell[0],
                                             sizeof(float));
  if (bps_cont.InputData == NULL)
    return 41; /* Can't Allocate To \"InputData\" */

  /* ALLOCATE TEACH DATA (NumOfPtrn x output-layer cells) */
  bps_cont.TeachData = (float**)bps_malloc2D(bps_cont.NumOfPtrn,
                                             bps_cont.NumOfCell
                                             [bps_cont.NumOfLayer-1],
                                             sizeof(float));
  if (bps_cont.TeachData == NULL)
    return 42; /* Can't Allocate To \"TeachData\" */

  printf("before ReadData routine\n");
  printf("InputFile = %s\n", bps_cont.InputFile);
  printf("TeachFile = %s\n", bps_cont.TeachFile);

  printf("LrnFirstPtrn = %d\n", bps_cont.LrnFirstPtrn);
  printf("LenLastPtrn  = %d\n", bps_cont.LrnLastPtrn);

  ReadData(bps_cont.InputFile, bps_cont.InputData, bps_cont.LrnFirstPtrn,
           bps_cont.LrnLastPtrn, bps_cont.NumOfCell[0]);
  ReadData(bps_cont.TeachFile, bps_cont.TeachData, bps_cont.LrnFirstPtrn,
           bps_cont.LrnLastPtrn, bps_cont.NumOfCell[bps_cont.NumOfLayer-1]);

  /* ALLOCATE TO OutCellErr.
     forward_learn() adds squared errors into this array with "+=",
     so it must start from zero. */
  bps_cont.OutCellErr = (double*)emalloc((bps_cont.NumOfCell[bps_cont.NumOfLayer-1]+1)*sizeof(double));
  if (bps_cont.OutCellErr == NULL)
    return 37; /* Can't Allocate To \"OutCellErr\" */
  memset(bps_cont.OutCellErr, 0,
         (bps_cont.NumOfCell[bps_cont.NumOfLayer-1]+1)*sizeof(double));

  /* ALLOCATE TO ErrBuffer (one slot per learning count, 0..MaxLearnCount) */
  bps_cont.ErrBuffer = (float*)emalloc((bps_cont.MaxLearnCount+1)*sizeof(float));
  if (bps_cont.ErrBuffer == NULL)
    return 37; /* Can't Allocate To \"ErrBuffer\" (shares code 37) */
  memset(bps_cont.ErrBuffer, 0, sizeof(float)*(bps_cont.MaxLearnCount+1));

  /* ALLOCATE TO buff_err (one slot per learning count, 0..MaxLearnCount) */
  bps_cont.buff_err = (double*)emalloc((bps_cont.MaxLearnCount+1)*
                                       sizeof(double));
  if (bps_cont.buff_err == NULL)
    return 37; /* Can't Allocate To \"buff_err\" (shares code 37) */
  memset(bps_cont.buff_err, 0, sizeof(double)*(bps_cont.MaxLearnCount+1));

  return 0;
}


/************************************************
  system end

  Releases OutCellErr (efree) and InputData / TeachData (bps_free2D),
  the buffers allocated by system_initialize(), and sets each pointer
  in bps_cont to NULL so later code sees NULL instead of a dangling
  pointer.

  NOTE(review): ErrBuffer and buff_err, also allocated by
  system_initialize(), are not released here. Confirm whether another
  module takes ownership of them or whether they leak.
  ************************************************/
void system_end()
{
  efree(bps_cont.OutCellErr);
  bps_cont.OutCellErr = NULL;

  bps_free2D((char**)bps_cont.InputData);
  bps_cont.InputData = NULL;

  bps_free2D((char**)bps_cont.TeachData);
  bps_cont.TeachData = NULL;
}


/************************************************
  forward learning routine

  Propagates one pattern through the network and measures the
  output-layer error.
  inputs:
  inputdata : input data; inputdata[k] becomes the Net of cell k+1
              of layer 0
  teachdata : teach data; teachdata[k] is the target of cell k+1
              of the last layer
  returns:
  sum over output cells of (teachdata - Active)^2
  side effects:
  sets Net and Active on every cell, sets Delta = teach - Active on
  each output cell, and adds that cell's squared error into
  bps_cont.OutCellErr[unit-1].
  ************************************************/
double forward_learn(float *inputdata, float *teachdata)
{
  int         last_lay = bps_cont.NumOfLayer - 1;
  int         li, ci;
  double      sum_err;
  bps_cel_t  *node, *src;
  bps_ilin_t *in_link;

  /* Layer 0: the net input of each cell is the pattern value itself. */
  for (ci = 1; ci <= bps_cont.NumOfCell[0]; ci++) {
    node         = bps_cont.BPNet[0][ci].CellNode;
    node->Net    = (double) inputdata[ci-1];
    node->Active = FuncSelect(node->CharFunc, node->Net);
  }

  /* Remaining layers: Net is the weighted sum of the source cells' Active. */
  for (li = 1; li < bps_cont.NumOfLayer; li++) {
    for (ci = 1; ci <= bps_cont.NumOfCell[li]; ci++) {
      double acc = 0.0;

      node = bps_cont.BPNet[li][ci].CellNode;
      for (in_link = Getintoplist(node); in_link != NULL;
           in_link = Getinfwdlist(in_link)) {
        src  = in_link->InputCell;
        acc += in_link->Weight * src->Active;
      }
      node->Net    = acc;
      node->Active = FuncSelect(node->CharFunc, acc);
    }
  }

  /* Output layer: record the error term and accumulate squared error. */
  sum_err = 0.0;
  for (ci = 1; ci <= bps_cont.NumOfCell[last_lay]; ci++) {
    node        = bps_cont.BPNet[last_lay][ci].CellNode;
    node->Delta = (double)(teachdata[ci-1]) - node->Active;

    bps_cont.OutCellErr[ci-1] += node->Delta * node->Delta;
    sum_err                   += node->Delta * node->Delta;
  }
  return sum_err;
}


/* Adds node->Delta * (source cell's Active) to WgtWork on every input
   link of node.  Used by both passes of backward_learn(). */
static void accumulate_wgt_work(bps_cel_t *node)
{
  bps_ilin_t *link;
  bps_cel_t  *src;

  for (link = Getintoplist(node); link != NULL; link = Getinfwdlist(link)) {
    src            = link->InputCell;
    link->WgtWork += node->Delta * src->Active;
  }
}

/************************************************
  backward learning routine

  Back-propagates the error terms and adds the per-link products into
  WgtWork (Weight itself is not modified here).
  Output layer: Delta (set to teach - Active by forward_learn) is
  multiplied by DiffSelect(CharFunc, Net).
  Layers NumOfLayer-2 down to 0: Delta is the sum over the cell's
  output links of (destination Delta * link Weight), multiplied by
  DiffSelect(CharFunc, Net).
  ************************************************/
void backward_learn()
{
  int         lay, idx;
  double      sum;
  bps_cel_t  *cel, *dst;
  bps_ilin_t *ilink;
  bps_olin_t *olink;

  /* Output layer */
  lay = bps_cont.NumOfLayer - 1;
  for (idx = 1; idx <= bps_cont.NumOfCell[lay]; idx++) {
    cel        = bps_cont.BPNet[lay][idx].CellNode;
    cel->Delta = (cel->Delta * DiffSelect(cel->CharFunc, cel->Net));
    accumulate_wgt_work(cel);
  }

  /* Every lower layer, from the last hidden layer down to layer 0 */
  for (lay = bps_cont.NumOfLayer - 2; lay >= 0; lay--) {
    for (idx = 1; idx <= bps_cont.NumOfCell[lay]; idx++) {
      cel = bps_cont.BPNet[lay][idx].CellNode;

      sum = 0.0;
      for (olink = Getouttoplist(cel); olink != NULL;
           olink = Getoutfwdlist(olink)) {
        ilink = Getouttolist(olink);
        dst   = ilink->NodeCell;
        sum  += dst->Delta * ilink->Weight;
      }

      cel->Delta = (sum * DiffSelect(cel->CharFunc, cel->Net));
      accumulate_wgt_work(cel);
    }
  }
}

#ifdef __cplusplus
}
#endif
