//********************************************************//
//
//  Class   :   BackProp
//  Author  :   Keith A. Pray
//  Date    :   October 2000 (ported)
//
//********************************************************//

import BPNN;
import java.lang.Math;

/** 
 * So we have to use this code to build a neural network
 * for the census income data, etc.
 * 
 * Since I have all the code for file parsing and fancy data
 * type handling already written in Java, it seems far easier
 * to port this code to Java. Not only does this avoid the tedious
 * writing of file parsing code for this particular application,
 * it forces me to examine the back propagation method in great
 * detail.
 *
 * @author Keith A. Pray
 * @version 1.0, September 2000
 */

/*
 ********************************************************
 * HISTORY
 * 15-Oct-94  Jeff Shufelt (js), Carnegie Mellon University
 *      Prepared for 15-681, Fall 1994.
 *
 * Tue Oct  7 08:12:06 EDT 1997, bthom, added a few comments,
 *       tagged w/bthom
 *
 ********************************************************
 */

public class BackProp
{

//********************************************************//
// ----- Data Members -----
//********************************************************//

  /**
   * The neural network.
   * NOTE(review): not referenced by any method in this class; the
   * training methods take the network as an explicit parameter.
   */
  BPNN net;

  /**
   * Sum of |delta| over the output units computed by the most recent
   * call to bpnn_train. Stored here because Java passes doubles by
   * value, so bpnn_train cannot hand it back through its eo parameter.
   */
  double outputError;

  /**
   * Sum of |delta| over the hidden units computed by the most recent
   * call to bpnn_train (stored here for the same reason as outputError).
   */
  double hiddenError;

//********************************************************//
// ----- Methods-----
//********************************************************//

  /**
   * Constructs a new BackProp object
   */

  BackProp()
  {
    // do nothing for now

  }	// END BackProp()

//********************************************************//

  /**
   * The squashing function.  Currently, it's a sigmoid
   * as in formula 4.12 and illustrated in figure 4.6
   * from text.
   *
   * @param x
   * the value to squash
   *
   * @return
   * 1 / (1 + e^-x), a value between 0 and 1
   */

  double squash ( double x )
  {
    return ( 1.0 / ( 1.0 + Math.exp ( -x ) ) );

  }	// END squash ( double x )

//********************************************************//

  /**
   * Propagates activations from one layer to the next.
   * Unit 0 of the source layer is the threshold (bias) unit and is
   * clamped to 1.0 here, so conn[0][j] acts as the bias weight of
   * unit j in the destination layer.
   *
   * @param l1
   * activations of the source layer, indices 0..n1
   * (index 0 is overwritten with 1.0)
   *
   * @param l2
   * receives the squashed outputs of the destination layer in
   * indices 1..n2 (index 0 is not written)
   *
   * @param conn
   * weights; conn[k][j] is the weight on the connection from unit k
   * of the source layer to unit j of the destination layer
   *
   * @param n1
   * number of (non-threshold) units in the source layer
   *
   * @param n2
   * number of units in the destination layer
   */

  public void bpnn_layerforward ( double[] l1, double[] l2,
                                  double[][] conn,
                                  int n1, int n2 )
  {
    /*** Set up thresholding unit ***/
    l1[0] = 1.0;

    /*** For each unit in second layer ***/
    for ( int j = 1; j <= n2; j++ )
    {
      /** weighted sum of the inputs to unit j (Figure 4.6);
       ** k = 0 picks up the bias weight times the clamped 1.0 **/
      double sum = 0.0;
      for ( int k = 0; k <= n1; k++ )
      {
        sum += conn[k][j] * l1[k];
      }

      /** sigmoid output, Formula 4.12 (Figure 4.6) **/
      l2[j] = squash ( sum );
    }

  } // END void bpnn_layerforward ( l1, l2, conn, n1, n2 )

//********************************************************//

  /**
   * Calculates the error term (delta) for each output unit using
   * formula T4.3 from Table 4.2 in the text:
   * delta_j = o_j (1 - o_j) (t_j - o_j).
   *
   * @param delta
   * receives the error terms in indices 1..nj
   *
   * @param target
   * desired outputs, indices 1..nj
   *
   * @param output
   * actual outputs, indices 1..nj
   *
   * @param nj
   * number of output units
   *
   * @param err
   * ignored; kept so existing callers still compile. Java passes
   * doubles by value, so the error sum cannot be written back through
   * this parameter. Use the return value instead.
   *
   * @return
   * the sum of |delta_j| over the output units (for reporting)
   */

  public double bpnn_output_error ( double[] delta,
                                    double[] target,
                                    double[] output,
                                    int nj, double err )
  {
    double errsum = 0.0;

    /** for each output unit **/
    for ( int j = 1; j <= nj; j++ )
    {
      double o = output[j];
      double t = target[j];

      delta[j] = o * ( 1.0 - o ) * ( t - o );

      /** This sum is only for reporting. The hidden-unit error terms
       ** are computed from delta[] and the hidden->output weights in
       ** bpnn_hidden_error, not from this total.
       **/
      errsum += Math.abs ( delta[j] );
    }

    return errsum;

  } // END double bpnn_output_error ( delta, target, output, nj, err )

//********************************************************//

  /**
   * Calculates the error term (delta) for each hidden unit using
   * formula T4.4 from Table 4.2 in the text:
   * delta_h = o_h (1 - o_h) * sum over k of ( w_kh * delta_k ).
   *
   * @param delta_h
   * receives the hidden error terms in indices 1..nh
   *
   * @param nh
   * number of hidden units
   *
   * @param delta_o
   * output error terms (as computed by bpnn_output_error), indices 1..no
   *
   * @param no
   * number of output units
   *
   * @param who
   * hidden->output weights; who[j][k] connects hidden unit j to
   * output unit k
   *
   * @param hidden
   * hidden unit activations, indices 1..nh
   *
   * @param err
   * ignored; kept so existing callers still compile (Java cannot
   * write back through a double parameter). Use the return value.
   *
   * @return
   * the sum of |delta_h| over the hidden units (for reporting)
   */

  public double bpnn_hidden_error ( double[] delta_h, int nh,
                                    double[] delta_o, int no,
                                    double[][] who, double[] hidden,
                                    double err )
  {
    double errsum = 0.0;

    /** for each hidden unit **/
    for ( int j = 1; j <= nh; j++ )
    {
      double h = hidden[j];

      /** Error back-propagated from the output layer. It must start
       ** from zero for every hidden unit; otherwise the contributions
       ** of earlier hidden units leak into this one.
       **/
      double sum = 0.0;
      for ( int k = 1; k <= no; k++ )
      {
        /** delta_o[k] is the output unit's error term delta_k from T4.3 **/
        sum += delta_o[k] * who[j][k];
      }

      /** This corresponds to T4.4 in Table 4.2 **/
      delta_h[j] = h * ( 1.0 - h ) * sum;

      /** maintain the error sum for reporting **/
      errsum += Math.abs ( delta_h[j] );
    }

    return errsum;

  } // END double bpnn_hidden_error ( delta_h, nh, delta_o, no,
    //				     who, hidden, err )

//********************************************************//

  /**
   * Adjusts the weights feeding one layer using the delta rule with
   * momentum:
   * dw = eta * delta[j] * ly[k] + momentum * oldw[k][j]
   * (T4.5 from Table 4.2 plus the momentum term of Equation 4.18).
   *
   * Index 0 of ly is the threshold (bias) unit and is clamped to 1.0,
   * so w[0][j] is the bias weight of receiving unit j. Column 0 of w
   * is never touched because the receiving units are numbered from 1.
   *
   * @param delta
   * error terms of the receiving layer, indices 1..ndelta
   *
   * @param ndelta
   * number of units in the receiving layer
   *
   * @param ly
   * activations of the layer feeding the weights, indices 0..nly
   * (index 0 is overwritten with 1.0)
   *
   * @param nly
   * number of (non-threshold) units in the feeding layer
   *
   * @param w
   * weights adjusted in place; w[k][j] connects feeding unit k to
   * receiving unit j
   *
   * @param oldw
   * the weight changes applied by the previous call; overwritten in
   * place with the changes applied by this call
   *
   * @param eta
   * learning rate
   *
   * @param momentum
   * fraction of the previous weight change that is applied again
   */

  public void bpnn_adjust_weights ( double[] delta, int ndelta,
                                    double[] ly, int nly,
                                    double[][] w,
                                    double[][] oldw,
                                    double eta,
                                    double momentum )
  {
    /** clamp the threshold unit so w[0][j] behaves as a bias weight **/
    ly[0] = 1.0;

    /** for each receiving unit **/
    for ( int j = 1; j <= ndelta; j++ )
    {
      /** for each feeding unit, including the threshold unit 0 **/
      for ( int k = 0; k <= nly; k++ )
      {
        double new_dw = ( eta * delta[j] * ly[k] ) + ( momentum * oldw[k][j] );

        w[k][j] += new_dw;

        /** remember this change for the next call's momentum term **/
        oldw[k][j] = new_dw;
      }
    }
  } // END void bpnn_adjust_weights ( delta, ndelta, ly, nly, w, oldw,
    //				      eta, momentum )

//********************************************************//

  /**
   * Feeds net.input_units forward through the network, filling
   * net.hidden_units and then net.output_units.
   *
   * @param net
   * the network to evaluate
   */

  public void bpnn_feedforward ( BPNN net )
  {
    int in = net.input_n;
    int hid = net.hidden_n;
    int out = net.output_n;

    /*** Feed forward input activations. ***/
    bpnn_layerforward ( net.input_units, net.hidden_units,
                        net.input_weights, in, hid );

    bpnn_layerforward ( net.hidden_units, net.output_units,
                        net.hidden_weights, hid, out );

  } // END void bpnn_feedforward ( BPNN net )

//********************************************************//

  /**
   * Trains the neural network on the pattern currently loaded in
   * net.input_units and net.target: feed forward, compute the output
   * and hidden error terms, then adjust both weight layers.
   * The hidden error terms are computed before any weights change.
   *
   * The output and hidden error sums are afterwards available from
   * getOutputError() and getHiddenError().
   *
   * @param net
   * the network to train
   *
   * @param eta
   * learning rate
   *
   * @param momentum
   * momentum factor
   *
   * @param eo
   * ignored; Java passes doubles by value, so the output error cannot
   * be returned through it. Use getOutputError().
   *
   * @param eh
   * ignored for the same reason. Use getHiddenError().
   */

  public void bpnn_train ( BPNN net, double eta, double momentum,
                           double eo, double eh )
  {
    int in = net.input_n;
    int hid = net.hidden_n;
    int out = net.output_n;

    /*** Feed forward input activations. ***/
    bpnn_feedforward ( net );

    /*** Compute error on output and hidden units. ***/
    outputError = bpnn_output_error ( net.output_delta, net.target,
                                      net.output_units, out, eo );

    hiddenError = bpnn_hidden_error ( net.hidden_delta, hid,
                                      net.output_delta, out,
                                      net.hidden_weights, net.hidden_units,
                                      eh );

    /*** Adjust hidden->output weights, then input->hidden weights. ***/
    bpnn_adjust_weights ( net.output_delta, out, net.hidden_units, hid,
                          net.hidden_weights, net.hidden_prev_weights,
                          eta, momentum );

    bpnn_adjust_weights ( net.hidden_delta, hid, net.input_units, in,
                          net.input_weights, net.input_prev_weights,
                          eta, momentum );

  } // END void bpnn_train ( net, eta, momentum, eo, eh )

//********************************************************//

  /**
   * @return
   * the sum of |delta| over the output units from the most recent
   * bpnn_train call (0.0 before the first call)
   */

  public double getOutputError ()
  {
    return outputError;
  }

  /**
   * @return
   * the sum of |delta| over the hidden units from the most recent
   * bpnn_train call (0.0 before the first call)
   */

  public double getHiddenError ()
  {
    return hiddenError;
  }

//********************************************************//

/** Saves a neural network to a file for later retrieval **/
  // NOTE: not yet ported; the original C implementation is kept below
  // for reference only.
  /*
  public void bpnn_save ( BPNN net, String filename )
  {
    int fd, n1, n2, n3, i, j, memcnt;
    double dvalue;
    double[][] w;
    String mem;
    
    if ( ( fd = creat ( filename, 0644 ) ) == -1 ) 
    {
      printf ( "BPNN_SAVE: Cannot create '%s'\n", filename );
      return;
    }
    
    n1 = net->input_n;  n2 = net->hidden_n;  n3 = net->output_n;
    
    printf ( "Saving %dx%dx%d network to '%s'\n", n1, n2, n3, filename );
    fflush ( stdout );
    
    write ( fd, (char *) &n1, sizeof(int) );
    write ( fd, (char *) &n2, sizeof(int) );
    write ( fd, (char *) &n3, sizeof(int) );
    
    memcnt = 0;
    w = net->input_weights;
    mem =  malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
    for ( i = 0; i <= n1; i++ ) 
    {
      for ( j = 0; j <= n2; j++ ) 
      {
	dvalue = w[i][j];
	fastcopy ( &mem[memcnt], &dvalue, sizeof(double) );
	
	memcnt += sizeof(double);
      }
    }
    write ( fd, mem, (n1+1) * (n2+1) * sizeof(double) );
    free ( mem ) ;
    
    memcnt = 0;
    w = net.hidden_weights;
    mem = (char *) malloc ((unsigned) ((n2+1) * (n3+1) * sizeof(double)));
    for (i = 0; i <= n2; i++) 
    {
      for (j = 0; j <= n3; j++) 
      {
	dvalue = w[i][j];
	fastcopy(mem[memcnt], dvalue, sizeof(double));
	memcnt += sizeof(double);
      }
    }
    write(fd, mem, (n2+1) * (n3+1) * sizeof(double));
    free ( mem );
    
    close ( fd );
    return;
  */    
  //  } // END void bpnn_save ( net, filename )
  
//********************************************************//

/** Reads in a neural network from a file **/
  // NOTE: not yet ported; the original C implementation is kept below
  // for reference only.
  /*
BPNN *bpnn_read ( filename )
     char *filename;
{
  char *mem;
  BPNN *new;
  int fd, n1, n2, n3, i, j, memcnt;

  if ( ( fd = open ( filename, 0, 0644 ) ) == -1 ) 
  {
    return ( NULL );
  }

  printf ( "Reading '%s'\n", filename );  
  fflush ( stdout );

  read ( fd, (char *) &n1, sizeof(int) );
  read ( fd, (char *) &n2, sizeof(int) );
  read ( fd, (char *) &n3, sizeof(int) );
  new = bpnn_internal_create ( n1, n2, n3 );

  printf ( "'%s' contains a %dx%dx%d network\n", filename, n1, n2, n3 );
  printf ( "Reading input weights..." );  
  fflush ( stdout );

  memcnt = 0;
  mem = (char *) malloc ((unsigned) ((n1+1) * (n2+1) * sizeof(double)));
  read(fd, mem, (n1+1) * (n2+1) * sizeof(double));
  for ( i = 0; i <= n1; i++ )
  {
    for ( j = 0; j <= n2; j++ ) 
    {
      fastcopy ( &( new->input_weights[i][j] ), &mem[memcnt], sizeof(double) );
      memcnt += sizeof(double);
    }
  }
  free ( mem );

  printf ( "Done\nReading hidden weights..." );
  fflush ( stdout );

  memcnt = 0;
  mem = (char *) malloc ((unsigned) ( (n2+1) * (n3+1) * sizeof(double) ) );
  read ( fd, mem, (n2+1) * (n3+1) * sizeof(double) );
  for ( i = 0; i <= n2; i++ )
  {
    for ( j = 0; j <= n3; j++ ) 
    {
      fastcopy ( &( new->hidden_weights[i][j] ), 
		 &mem[memcnt], sizeof(double) );
      memcnt += sizeof(double);
    }
  }
  free ( mem );
  close ( fd );

  printf ( "Done\n" );
  fflush(stdout);

  bpnn_zero_weights ( new->input_prev_weights, n1, n2 );
  bpnn_zero_weights ( new->hidden_prev_weights, n2, n3 );

  return ( new );
  */
  //} /** END BPNN *bpnn_read(filename) **/

//********************************************************//

}	// END class BackProp

//********************************************************//
