//********************************************************//
//
//  Class   :   BPNN
//  Author  :   Keith A. Pray
//  Date    :   October 2000 (ported)
//
//********************************************************//

//import java.util.Vector;

import java.util.Random;

/**
 * This class represent a back propogation neural network.
 *
 * @author Keith A. Pray
 * @version 1.0, September 2000
 */

public class BPNN
{

//********************************************************//
// ----- Data Members -----
//********************************************************//

  /** number of input nodes */
  public int input_n;

  /** number of hidden nodes */
  public int hidden_n;

  /** number of output nodes */
  public int output_n;

  /** the input units */
  public double[] input_units;

  /** the hidden units */
  public double[] hidden_units;

  /** the output units */
  public double[] output_units;
 
  /** storage for hidden units error */
  public double[] hidden_delta;
 
  /** storage for output units error */
  public double[] output_delta;

  /** storage for target */
  public double[] target;

  /** weights from input to hidden layer */
  public double[][] input_weights;
  
  /** weights from hidden to output layer */
  public double[][] hidden_weights;
  
  // The next two are for momentum
  
  /** previous change on input to hidden wgt */
  public double[][] input_prev_weights;
  
  /** previous change on hidden to output wgt */
  public double[][] hidden_prev_weights;

  // Defaults
  public final int inDefault = 3;
  public final int hiddenDefault = 3;
  public final int outDefault = 3;

  /** Random number generator */
  private Random random;

  /** should we initialize to zero? */
  public boolean INITZERO;

//********************************************************//
// ----- Methods-----
//********************************************************//

  /**
   * Constructs a new BPNN object with default values
   */
  
  BPNN()
  {
    // use the defaults for number of in node, hidden nodes
    // and out nodes
    Init ( inDefault, hiddenDefault, outDefault );

  } // END BPNN()

//********************************************************//

  /**
   * Constructs a new BPNN object with specified values
   * for the number of nodes.
   *
   * @param in
   * number of input nodes
   *
   * @param hidden
   * number of hidden nodes
   *
   * @param out
   * number of output nodes 
   *
   */
  
  BPNN ( int in, int hidden, int out )
  {
    // make sure none of them are zero or less
    if ( ( in <= 0 ) || ( hidden <= 0 ) || ( out <= 0 ) )
    {
      // output error message
      System.err.println ( "cannot have zero or less nodes in any level\n" );

      return;

    } // end if any param is zero or less

    Init ( in, hidden, out );

  }	// END BPNN()

//********************************************************//

  private void Init ( int in, int hidden, int out )
  {
    // set to not use zero for initial values by default
    INITZERO = false;

    // make the random number generator to use
    random = new Random();

    input_n = in;
    hidden_n = hidden;
    output_n = out;
    
    // make arrays to hold stuff
    input_units = new double[in +1];
    hidden_units = new double[hidden +1];
    output_units = new double[out +1];
      
    hidden_delta = new double[hidden +1];
    output_delta = new double[out +1];
    target = new double[out +1];
      
    input_weights = new double[in + 1][hidden + 1];
    hidden_weights = new double[hidden + 1][out + 1];
      
    input_prev_weights = new double[in + 1][hidden + 1];
    hidden_prev_weights = new double[hidden + 1][out + 1];

  } // END void Init ( int in, int hidden, int out )

//********************************************************//

  /** 
   * Creates a new fully-connected network from scratch,
   * with the given numbers of input, hidden, and output units.
   * Threshold units are automatically included.  All weights are
   * randomly initialized.
   *   
   * Space is also allocated for temporary storage (momentum weights,
   * error computations, etc).
   */

  private void InitWeights()
  {
    if ( INITZERO )
    {
      bpnn_zero_weights ( input_weights, input_n, hidden_n );
    }
    else
    {
      bpnn_randomize_weights ( input_weights, input_n, hidden_n );
    }
    
    bpnn_randomize_weights ( hidden_weights, hidden_n, output_n );
    
    bpnn_zero_weights ( input_prev_weights, input_n, hidden_n );
    bpnn_zero_weights ( hidden_prev_weights, hidden_n, output_n );
    
  } // END void bpnn_create()
  
//********************************************************//

/** Seed the random number generator **/

void bpnn_initialize ( long seed )
{
  System.out.println ( "Random number generator seed: " +
		       seed + "\n" );

  random.setSeed ( seed );

} // END void bpnn_initialize ( seed )

//********************************************************//
  
  /** 
   * initializes weights in a 2D array to random values 
   * -1.0 to 1.0
   */
  
  private void bpnn_randomize_weights ( double[][]w, int m, int n )
  {
    int i, j;
    
    for ( i = 0; i <= m; i++ )
    {
      for ( j = 0; j <= n; j++ )
      {
        w[i][j] = ( ( random.nextDouble() * 2 ) - 1 );
      }
    }
    
} // END void bpnn_randomize_weights ( w, m, n )
  
//********************************************************//

  /**
   * initialize a 2D array to 0.0
   */
  
  private void bpnn_zero_weights ( double[][]w, int m, int n )
  {
    int i, j;
    
    for ( i = 0; i <= m; i++ )
    {
      for ( j = 0; j <= n; j++ )
      {
        w[i][j] = 0.0;
      }
    }
    
  } // END void bpnn_zero_weights ( w, m, n )
  
//********************************************************//

} // END public class BPNN
