
import weka.classifiers.*;
import weka.filters.*;
import weka.core.*;
import java.util.*;

/**
 * Implements the back propagation algorithm
 * Globally replaces all missing values with the mode (nominal
 * attributes) or mean (numeric attributes) from the data set.
 *
 * Ported from: backprop.c, backprop.h and facetrain.c
 *
 * 2000-12-11 - Changing default settings to those that produced 
 * best results to date. -kap
 *
 *  ******************************************************************
 *  * HISTORY
 *  * 15-Oct-94  Jeff Shufelt (js), Carnegie Mellon University
 *  *      Prepared for 15-681, Fall 1994.
 *  *
 *  ******************************************************************
 *
 * Valid options are:<p>
 *
 * -I num <br>
 * The number of iterations through training data. (default 6000)<p>
 *
 * -E num <br>
 * The learning rate. (default 0.0112)<p>
 *
 * -S num <br>
 * The seed for the random number generator. (default system time)<p>
 *
 * -H num <br>
 * The number of hidden nodes to be used. (default half the output nodes)<p>
 *
 * -M num <br>
 * The momentum. (default 0.01)
 *
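 * Example command line (the ARFF file name is hypothetical;
 * the -t option is handled by weka.classifiers.Evaluation):<p>
 *
 * java BackPropagation -t train.arff -I 6000 -E 0.0112 -M 0.01 <p>
 *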
 * @author Ported by: Keith A. Pray
 *
 */

public class BackPropagation extends DistributionClassifier 
    implements OptionHandler 
{
  /** The number of iterations */
  private int m_NumIterations = 6000;
  
  /** The learning rate */
  private double m_LearningRate = 0.0112;
  
  /** The training instances */
  private Instances m_Train = null;
  
  /** Seed used for shuffling the dataset */
  private int m_Seed = 1;

  /** The momentum of current movement in the search space */
  private double m_Momentum = 0.01;

  /** The number of iterations before printing out status */
  private int statusFrequency = 100;
  
  /** The random number generator **/
  private Random random = null;
  
  /** The filter used to make nominal attributes binary
   ** (using my own specialized version) */
  private NominalToBinaryFilter m_NominalToBinary = null;
  
  /** The filter used to get rid of missing values. */
  private ReplaceMissingValuesFilter m_ReplaceMissingValues = null;
  
  /** The number of input nodes */
  private int input_n = 0 ;
  
  /** The number of hidden nodes */
  private int hidden_n = 0;
  
  /** The number of output nodes */
  private int output_n = 0;
  
  /** The input nodes */
  private double[] input_nodes = null;
  
  /** The hidden nodes */
  private double[] hidden_nodes = null;
  
  /** The output nodes */
  private double[] output_nodes = null;
  
  /** The storage for hidden node error */
  private double[] hidden_delta = null;
  
  /** The storage for output node error */
  private double[] output_delta = null;
  
  /** The storage for targets */
  private double[] target = null;
  
  /** The weights from input to hidden nodes */
  private double[][] input_weights = null;
  
  /** The weights from hidden to output nodes */
  private double[][] hidden_weights = null;
  
  // The next two are for momentum
  
  /** The previous weights from input to hidden nodes */
  private double[][] input_prev_weights = null;
  
  /** The previous weights from hidden to output nodes */
  private double[][] hidden_prev_weights = null;
  
//********************************************************//
  
  /**
   * Returns an enumeration describing the available options
   *
   * @return an enumeration of all the available options
   */
  
  public Enumeration listOptions() 
  {
    Vector newVector = new Vector ( 5 );
    
    newVector.addElement ( new Option ( "\tThe number of iterations to be performed.\n"
					+ "\t(default 6000)",
					"I", 1, "-I <int>" ) );

    newVector.addElement ( new Option ( "\tThe learning rate.\n"
					+ "\t(default 0.0112)",
					"E", 1, "-E <double>" ) );

    newVector.addElement ( new Option ( "\tThe seed for the random number generation.\n"
					+ "\t(default current time)",
					"S", 1, "-S <int>" ) );

    newVector.addElement ( new Option ( "\tThe number of hidden nodes to be use.\n"
					+ "\t(default half the output nodes)",
					"H", 1, "-H <int>" ) );

    newVector.addElement ( new Option ( "\tThe momentum.\n"
					+ "\t(default 0.01)",
					"M", 1, "-M <double>" ) );
    
    return newVector.elements();

  } // END public Enumeration listOptions()
  
//********************************************************//
  
  /**
   * Parses a given list of options. Valid options are:<p>
   *
   * -I num <br>
   * The number of iterations to be performed. (default 6000)<p>
   *
   * -E num <br>
   * The learning rate. (default 0.0112)<p>
   *
   * -S num <br>
   * The seed for the random number generator. (default current time)<p>
   *
   * -H num <br>
   * The number of hidden nodes to be used. (default half the output nodes) <p>
   *
   * -M num <br>
   * The momentum of current direction to be used. (default 0.01) <p>
   *
   * @param options the list of options as an array of strings
   * @exception Exception if an option is not supported
   */
  
  public void setOptions ( String[] options ) throws Exception 
  {  
    String iterationsString = Utils.getOption ( 'I', options );
    
    if (iterationsString.length() != 0) 
    {
      m_NumIterations = Integer.parseInt ( iterationsString );
    } 
    else 
    {
      m_NumIterations = 6000;
    }
    
    String learningRateString = Utils.getOption ( 'E', options );
    
    if (learningRateString.length() != 0) 
    {
      m_LearningRate = ( new Double ( learningRateString ) ).doubleValue();
    } 
    else 
    {
      m_LearningRate = 0.0112;
    }
    
    String seedString = Utils.getOption ( 'S', options );
    
    if ( seedString.length() != 0 )
    {
      m_Seed = Integer.parseInt ( seedString );
    } 
    else 
    {
      Date date = new Date();
      m_Seed = (int)(date.getTime());
    }
    
    String hiddenString = Utils.getOption ( 'H', options );
    
    if ( hiddenString.length() != 0 )
    {
      hidden_n = Integer.parseInt ( hiddenString );
    } 
    
    String momentumString = Utils.getOption ( 'M', options );
    
    if ( momentumString.length() != 0 )
    {
      m_Momentum =  ( new Double ( momentumString ) ).doubleValue();
    } 
    else 
    {
      m_Momentum = 0.01;
    }

  } // END public void setOptions ( String[] options ) 
    // throws Exception

//********************************************************//

  /**
   * Gets the current settings of the classifier.
   *
   * @return an array of strings suitable for passing to setOptions
   */
  
  public String[] getOptions() 
  {
    String[] options = new String [10];
    int current = 0;
    
    options[current++] = "-I"; options[current++] = "" + m_NumIterations;
    options[current++] = "-E"; options[current++] = "" + m_LearningRate;
    options[current++] = "-S"; options[current++] = "" + m_Seed;
    options[current++] = "-H"; options[current++] = "" + hidden_n;
    options[current++] = "-M"; options[current++] = "" + m_Momentum;
    
    while ( current < options.length )
    {
      options[current++] = "";
    }
    return options;

  } // END public String[] getOptions()
  
//********************************************************//

  /**
   * Builds the Neural Network with Back Propagation.
   *
   * @exception Exception if something goes wrong during building
   */
  
  public void buildClassifier ( Instances insts ) 
    throws Exception 
  {
    if ( insts.checkForStringAttributes() )
    {
      throw new Exception ( "Can't handle string attributes!" );
    }
    if ( insts.numClasses() > 2 )
    {
      throw new Exception ( "Can only handle two-class datasets!" );
    }
    if ( insts.classAttribute().isNumeric() )
    {
      throw new Exception ( "Can't handle a numeric class!" );
    }
    
    // Filter data
    // make the training set
    m_Train = new Instances ( insts );

    // can't use a training example if it doesn't have a class
    m_Train.deleteWithMissingClass();

    // take care of missing values
    m_ReplaceMissingValues = new ReplaceMissingValuesFilter();
    m_ReplaceMissingValues.inputFormat ( m_Train );
    m_Train = Filter.useFilter ( m_Train, m_ReplaceMissingValues );
    
    // change nominal values to binary so it will
    // work with our neural network

    // Note: using a binary representation of nominal values,
    // since encoding them as integers would introduce a bias
    // when training the net

    m_NominalToBinary = new NominalToBinaryFilter();
    m_NominalToBinary.inputFormat ( m_Train );

    m_Train = Filter.useFilter ( m_Train, m_NominalToBinary );

    // let's output a summary of what the data now looks like
//    System.out.println ( "Summary of data transformation: \n" );
//    System.out.println ( m_Train.toSummaryString() + "\n" );
//    System.out.println ( "Example: " +
//    			 m_Train.firstInstance().toString() + "\n" );
//    System.out.flush();

    /** Randomize training data */
    m_Train.randomize ( new Random ( m_Seed ) );
    
    // set seed for random number generator
    random = new Random ( m_Seed );
    
    // Make space to store neural net
    
    // need to know how many nodes and such
    
    // the number of input nodes will be the number
    // of attributes for our data set
    // minus one for the class attribute
    input_n = m_Train.numAttributes() - 1;

    // the number of output nodes will be equal to the number
    // of possible values the class attribute can take
    // (the class is guaranteed to be nominal by the checks above)
    output_n = m_Train.classAttribute().numValues();
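
    // (for example, a filtered data set with five attributes, one
    // of which is a two-valued class, gives input_n = 4 and
    // output_n = 2)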

    // is the number of hidden nodes set yet?
    if ( hidden_n == 0  )
    {
      // let's round up, this is useful for the case
      // of 2 output nodes... might as well have
      // 2 hidden nodes :)
      hidden_n = (int)( Math.ceil ( ( output_n + 1) / 2.0 ) );

    } // end hidden not set

    // the nodes
    input_nodes = new double[input_n + 1];
    hidden_nodes = new double[hidden_n + 1];
    output_nodes = new double[output_n + 1];
    
    // the deltas
    hidden_delta = new double[hidden_n + 1];
    output_delta = new double[output_n +1];
    
    // the target
    target = new double[output_n + 1];
    
    // the weights
    input_weights = new double[input_n + 1][hidden_n + 1];
    hidden_weights = new double[hidden_n + 1][output_n + 1];

    // the previous weights
    input_prev_weights = new double[input_n + 1][hidden_n + 1];
    hidden_prev_weights = new double[hidden_n + 1][output_n + 1];

    // initialize the weights
    // we will always use random initializations
    Randomize_Weights ( input_weights, input_n, hidden_n );
    Randomize_Weights ( hidden_weights, hidden_n, output_n );

    // while testing we'll use zero so we can compare more easily
//    Zero_Weights ( input_weights, input_n, hidden_n );
//    Zero_Weights ( hidden_weights, hidden_n, output_n );

    // need to initialize these to 0
    Zero_Weights ( input_prev_weights, input_n, hidden_n );
    Zero_Weights ( hidden_prev_weights, hidden_n, output_n );

    // print out the weights before training
    System.out.println ( toString() );

    /** Train the network */

    // this is like facetrain.c

    // also, print out the status according to status frequency
    int status = 1;

    for ( int it = 0; it < m_NumIterations; it++ )
    {
      double sumerr = 0.0;

      for ( int i = 0; i < m_Train.numInstances(); i++ )
      {
	Instance inst = m_Train.instance ( i );

	// as long as the class attribute is not missing,
	// train with this instance
	if ( ! inst.classIsMissing() ) 
	{
	  // train using back propagation algorithm
	  // and track error
	  sumerr += ( Train ( inst ) );
	  
      	} // end if
	else
	{
	  // output a little message (this shouldn't happen, since
	  // instances with missing classes were deleted above)
	  System.err.println ( i + ": missing class attribute\n" );
	}

      } // end for each instance

      if ( status >= statusFrequency )
      {
	// print out the error for this status
	System.out.println ( "Error sum for iteration " + (it + 1) +
			     " = " + sumerr + "\n" );
	// print out the network so far
//	System.out.println ( toString() );

	// reset status counter
	status = 1;

      } // end if time to print status
      else
      {
	// just increment the counter
	status++;
      }

    } // end for each iteration

  } // END  public void buildClassifier ( Instances insts ) 
    //    throws Exception

//********************************************************//

  /** 
   * Trains the neural network on the given instance:
   * loads the input and target, feeds forward, and
   * back propagates one step.
   *
   * @return the sum of the output and hidden error terms
   */

  private double Train ( Instance inst )
    throws Exception
  {
    // error reporting
    double out_err, hid_err = 0;

    // Set up input
    LoadInput ( inst );
    
    // Set up target
    LoadTarget ( inst );

    /*** Feed forward input activations. ***/
//    LayerForward ( input_nodes, hidden_nodes,
//		   input_weights, input_n, hidden_n );
    
//    LayerForward ( hidden_nodes, output_nodes,
//		   hidden_weights, hidden_n, output_n );
  
    FeedForward();
  
    /*** Compute error on output and hidden units. ***/
    out_err = OutputError();
    
    hid_err = HiddenError();
    
    /*** Adjust input and hidden weights. ***/
    
    AdjustWeights ( output_delta, output_n, hidden_nodes, hidden_n,
		    hidden_weights, hidden_prev_weights );
    
    AdjustWeights ( hidden_delta, hidden_n, input_nodes, input_n,
		    input_weights, input_prev_weights );

    // return the sum of the errors, since that is what
    // the training loop will do with them anyway
    return ( out_err + hid_err );

  } // END private double Train ( Instance inst)

//********************************************************//

  /** 
   * Adjusts the weights for the given units 
   *
   * @param delta
   * the delta calculated in error functions
   *
   * @param ndelta
   * number of deltas (the number of units in the destination layer)
   *
   * @param ly
   * the input (un-weighted)
   *
   * @param nly
   * number of inputs
   *
   * @param w
   * the weights (current) to be adjusted
   *
   * @param oldw
   * the old weights
   *
   */

  private void AdjustWeights ( double[] delta, int ndelta, 
			       double[] ly, int nly, 
			       double[][] w, 
			       double[][] oldw )
  {
    double new_dw = 1;
    int k, j;
    
    /** The arrays are indexed starting from 1, but the input loop
     ** below starts at 0: index 0 is the bias (threshold) unit,
     ** whose activation is fixed at 1.0, so w[0][j] is simply
     ** unit j's bias weight.
     **/
    ly[0] = 1.0;
    
    /** for each delta value **/
    for ( j = 1; j <= ndelta; j++ )
    {
      /** for each input **/
      for ( k = 0; k <= nly; k++ )
      {
	/** calculates the delta to be used to update the weight
	 ** the first part ( eta * delta[j] * ly[k] )
	 ** corresponds to T4.5 from Table 4.2
	 ** with the second part, the whole statement below
	 ** corresponds to Equation 4.18 in the text.
	 **/
	new_dw = ( ( m_LearningRate * delta[j] * ly[k] ) + 
		   ( m_Momentum * oldw[k][j] ) );
	
	/** add the delta weight to the current weight **/
	w[k][j] += new_dw;
	
	/** save this weight for later use in momentum **/
	oldw[k][j] = new_dw;
      }
    }
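
    // A worked example of the update above, with illustrative
    // numbers: m_LearningRate = 0.1, m_Momentum = 0.5,
    // delta[j] = 0.2, ly[k] = 1.0 and a previous step stored as
    // oldw[k][j] = 0.04 give
    //   new_dw = ( 0.1 * 0.2 * 1.0 ) + ( 0.5 * 0.04 ) = 0.04
    // so the weight moves by 0.04, and 0.04 is saved for the
    // next iteration's momentum term.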

    // just testing
//    System.out.println ( "delta weight: " + new_dw );

  } // END void AdjustWeights ( double[] delta, int ndelta, 
    //				double[] ly, int nly, 
    //				double[][] w, 
    //				double[][] oldw )
  
//********************************************************//

  /** 
   * Calculates the error for the output units 
   *
   * @return
   * the error sum
   */ 

  private double OutputError()
  {
    int j;
    double o, t, errsum;
    
    /** clean slate, no errors yet **/
    errsum = 0.0;
    
    // for each output node

    for ( j = 1; j <= output_n; j++ )
    {
      /** get the output value **/
      o = output_nodes[j];
      
      /** get the target value for this **/
      t = target[j];
      
      /** Formula T4.3 from Table 4.2 and formula 4.26 in the text:
       ** delta_j = o_j (1 - o_j) (t_j - o_j),
       ** part of the stochastic gradient descent rule
       ** for output units.
       **/
      output_delta[j] = o * ( 1.0 - o ) * ( t - o );
      
      /** Sum up the absolute error for reporting.
       ** Note: this sum is NOT what gets propagated back to the
       ** hidden layer; HiddenError() uses the individual
       ** output_delta[j] values together with the hidden->output
       ** weights (the delta[j] * weight[kh] sum of Table 4.2).
       ** The errsum here is only used for status reporting, e.g.
       ** by a driver deciding when to stop training.
       **/
      errsum += Math.abs ( output_delta[j] );

    } // end for each output node
    
    return ( errsum );

  } // END private double OutputError()

//********************************************************//

  /** 
   * Calculates the error for the hidden units 
   *
   * @return
   * the error
   */

  private double HiddenError()
  {
    int j, k;
    double h, sum, errsum;
    
    /** start with no error **/
    errsum = 0.0;
    
    /** for each hidden unit **/
    for (j = 1; j <= hidden_n; j++) 
    {
      /** get the hidden unit value **/
      h = hidden_nodes[j];
  
      // init the sum 
      sum = 0.0;
      
      /** for each output unit **/
      for (k = 1; k <= output_n; k++) 
      {
	/** Note: using output_delta[k] here is correct; T4.4 in
	 ** Table 4.2 propagates the output units' delta terms
	 ** (not their raw errors) back through the weights.
	 **/
	sum += output_delta[k] * hidden_weights[j][k];
      }
      
      /** This corresponds to T4.4 in Table 4.2 **/
      hidden_delta[j] = h * ( 1.0 - h ) * sum;
      
      /** maintain the error sum for reporting **/
      errsum += Math.abs ( hidden_delta[j] );
    }
    
    return ( errsum );
    
  } // END private double HiddenError()

//********************************************************//

  /**
   * Outputs the class distribution for the given instance,
   * computed by feeding it forward through the network.
   *
   * @param inst 
   * the instance for which distribution is to be computed
   *
   * @return 
   * the distribution
   *
   * @exception Exception if something goes wrong
   */

  public double[] distributionForInstance ( Instance inst ) 
    throws Exception 
  {

    // Filter instance
    m_ReplaceMissingValues.input ( inst );
    inst = m_ReplaceMissingValues.output();

    m_NominalToBinary.input ( inst );
    inst = m_NominalToBinary.output();
    
    // need counters
    int i = 0;

    // the array of probabilities to return
    // remember that this should ** NOT ** be made with 
    // ( output_n +1 ) elements like the output node array
    double [] dis = new double[output_n];

    // Get probabilities
    // we will do this using feed forward through the neural net
    // our results will be stored in the output nodes
    
    // set the input nodes to contain the values of the 
    // instance
    LoadInput ( inst );

    // ok, now we are ready to feed forward this instance
    FeedForward();

    // printing out output nodes for testing only
//    System.out.print ( "Output Nodes for Testing " );

//    for ( i = 0; i <= output_n; i++ )
//    {
//      System.out.print ( " : " + output_nodes[i] );
//    }
    
//    System.out.println ( "\n" );

    // shift the outputs down one index, since output_nodes
    // is referenced starting from index 1
    for ( i = 0; i < output_n; i++ )
    {
      dis[i] = output_nodes[i+1];
      
    } // end for shifting into new array
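
    // note: dis holds the raw sigmoid activations of the output
    // nodes, each in ( 0, 1 ); they are not normalized to sum
    // to one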

    return ( dis );

  } // END public double[] distributionForInstance ( Instance inst ) 
    //    throws Exception
  
//********************************************************//

  /**
   * Loads the instance into the input nodes of the net
   *
   */

  private void LoadInput ( Instance inst ) throws Exception
  {    
    // need counters
    int i, j = 0;

    // testing
//    System.out.println ( "\nFor instance: " + inst.toString() + "\n" );

    // walk over all the attributes, skipping the class attribute;
    // i indexes the input nodes (starting from 1), j the attributes
    i = 1;

    for ( j = 0; j < inst.numAttributes(); j++ )
    {
      // skip the class attribute wherever it is
      // (the old loop only handled a class in the last position)
      if ( j == inst.classIndex() )
      {
	continue;
      }

      // if this is a binomial attribute (not numeric)
      // then we should adjust the values here as we do
      // for the target value loading; we have to rely on
      // plain numeric data rarely equaling exactly 0 or 1,
      // since it would be very difficult to check the original
      // data set etc.
      if ( inst.value ( j ) == 0 )
      {
	input_nodes[i] = 0.1;
      }
      else if ( inst.value ( j ) == 1 )
      {
	input_nodes[i] = 0.9;
      }
      else
      {
	input_nodes[i] = inst.value ( j );
      }

      // go to next input node
      i++;

    } // end for each attribute

    // just testing
//    System.out.print ( "Loaded Input" );
//    for ( i = 1; i <= input_n; i++ )
//    {
//      System.out.print ( " : " + input_nodes[i] );
//    }

  } // END private void LoadInput ( Instance inst )

//********************************************************//

  /**
   * Loads the instance into the target array of the net
   *
   */

  private void LoadTarget ( Instance inst ) throws Exception
  {    
    // we want the target value to be 0.9 when
    // this node corresponds to the index of
    // the nominal value of the class attribute
    // and 0.1 otherwise
    int i = 0;

    // set all values in target to 0.1
    for ( i = 1; i <= output_n; i++ )
    {
      target[i] = 0.1;

    } // end set all values to 0.1
    
    // set the value of target at index corresponding
    // to the value of the class attribute to 0.9

    // remember to add 1 to the index since our target
    // vector is referenced starting from 1
    target[ (int)(inst.classValue()) + 1 ] = 0.9;
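
    // for example, a two-class problem with classValue() == 1
    // yields target[1] = 0.1 and target[2] = 0.9
    // (index 0 is unused)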
    
    // just testing
//    System.out.print ( "Loaded target" );
//    for ( i = 1; i <= output_n; i++ )
//    {
//      System.out.print ( " : " + target[i] );
//    }
//    System.out.println ( ": for value: " + inst.classValue() + "\n" );

  } // END private void LoadTarget ( Instance inst )

//********************************************************//

  /**
   * Returns textual description of classifier.
   */

  public String toString()
  {
    // the string to return
    String s = new String();

    // need counters
    int i, j = 0;

    s += "Back Propagation Neural Network:\n";
    s += "Numer Input Nodes: " + input_n + "\n";
    s += "Number Hidden Nodes: " + hidden_n + "\n";
    s += "Number Output Nodes: " + output_n + "\n";

    // add the options (flag/value pairs)
    String[] options = getOptions();

    for ( i = 0; i < options.length - 1; i += 2 )
    {
      s += "\t" + options[i] + "\t" + options[i+1] + "\n";
    }
	     
    // now for the weights

    // input weights first
    s += "Input Weights:\n";

    for ( i = 1; i <= input_n; i++ )
    {
      for ( j = 1; j <= hidden_n; j++ )
      {
	// print out this weight
	s += "\t" + input_weights[i][j];

      } // end for each hidden node

      // print the name of the attribute this input node
      // corresponds to (note: this labeling assumes the class
      // attribute is the last one in m_Train)
      s += "\t" + m_Train.attribute ( i - 1 ).name();

      // start a new line
      s += "\n";

    } // end for each input node

    // hidden weights next
    s += "Hidden Weights:\n";

    for ( i = 1; i <= hidden_n; i++ )
    {
      for ( j = 1; j <= output_n; j++ )
      {
	// print out this weight
	s += "\t" + hidden_weights[i][j];

      } // end for each output node

      // start a new line
      s += "\n";

    } // end for each hidden node

    // print out the classification for each output node
    try
    {
      for ( j = 1; j <= output_n; j++ )
      {
	s += "\t\t" + m_Train.classAttribute().value ( j - 1 );
      }
    }
    catch ( Exception ex )
    {
      System.err.println ( "Training data might not have class" );
    }

    s += "\n";

    // all done
    return ( s );

  } // END public String toString() 

//********************************************************//

  /**
   * Get the value of NumIterations.
   *
   * @return Value of NumIterations.
   */
  
  public int getNumIterations() 
  {  
    return m_NumIterations;

  } // END public int getNumIterations()

//********************************************************//
  
  /**
   * Set the value of NumIterations.
   *
   * @param v  Value to assign to NumIterations.
   */
  
  public void setNumIterations ( int v ) 
  {  
    m_NumIterations = v;

  } // END public void setNumIterations ( int v )
  
//********************************************************//
  
  /**
   * Get the value of learning rate.
   *
   * @return Value of learning rate.
   */

  public double getLearningRate() 
  {  
    return m_LearningRate;

  } // END public double getLearningRate()

//********************************************************//
  
  /**
   * Set the value of learning rate
   *
   * @param v  Value to assign to learning rate
   */

  public void setLearningRate ( double v )
  {  
    m_LearningRate = v;

  } // END public void setLearningRate ( double v )
  
//********************************************************//
  
  /**
   * Get the value of momentum
   *
   * @return Value of momentum
   */

  public double getMomentum() 
  {  
    return m_Momentum;

  } // END public double getMomentum()

//********************************************************//
  
  /**
   * Set the value of momentum
   *
   * @param v  Value to assign to momentum
   */

  public void setMomentum ( double v )
  {  
    m_Momentum = v;

  } // END public void setMomentum ( double v )
  
//********************************************************//
  
  /**
   * Get the value of Seed.
   *
   * @return Value of Seed.
   */

  public int getSeed() 
  {  
    return m_Seed;

  } // END public int getSeed()

//********************************************************//
  
  /**
   * Set the value of Seed.
   *
   * @param v  Value to assign to Seed.
   */
  
  public void setSeed ( int v )
  {    
    m_Seed = v;
    
    // update seed in random (if it has been created yet)
    if ( random != null )
    {
      random.setSeed ( m_Seed );
    }

  } // END public void setSeed ( int v )
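
  /** A sketch of programmatic use, assuming an Instances object
   ** "data" with its class index already set:
   **
   **   BackPropagation bp = new BackPropagation();
   **   bp.setNumIterations ( 1000 );
   **   bp.setLearningRate ( 0.05 );
   **   bp.setMomentum ( 0.01 );
   **   bp.buildClassifier ( data );
   **   double[] dist = bp.distributionForInstance ( data.firstInstance() );
   **/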

//********************************************************//
  
  /** 
   * Feed Forward pushes an instance through the neural net,
   * resulting in a classification for the current input.
   * The result is stored in the output nodes of the net.
   */
  
  private void FeedForward () 
    throws Exception 
  {
    // print out the net before
//    System.out.println ( "Input before:\n" + toString() );

    LayerForward ( input_nodes, hidden_nodes,
		   input_weights, input_n, hidden_n );

    // print out the net after
//    System.out.println ( "Input after:\n" + toString() );

    LayerForward ( hidden_nodes, output_nodes,
		   hidden_weights, hidden_n, output_n );

    // print out the net after'
//    System.out.println ( "Hidden after:\n" + toString() );

  } // END private void FeedForward() 
    //    throws Exception
  
//********************************************************//

  /**
   * Feeds the weighted values in the first level nodes
   * to the second level nodes.
   *
   */

  void LayerForward ( double[] l1, double[] l2, 
		      double[][]conn, int n1, int n2 )
  {
    double sum;
    int j, k;
    
    /*** Set up thresholding unit ***/
    l1[0] = 1.0;
    
    /*** For each unit in second layer ***/
    for ( j = 1; j <= n2; j++ )
    {
      /*** Compute weighted sum of its inputs ***/
      sum = 0.0;
      
      /** for each unit in layer 1 **/
      for ( k = 0; k <= n1; k++ ) 
      {
	/** corresponds to Figure 4.6, sum of weighted inputs to
	 ** unit. conn[k][j] contains the weight of the connection
	 ** from the k-th first layer unit to the j-th second
	 ** layer unit. l1[k] contains the input from the k-th
	 ** unit from layer 1.
	 **/
	sum += conn[k][j] * l1[k];
      }
      
      /** this computes the sigma as in Formula 4.12 (Figure 4.6) **/
      l2[j] = squash ( sum );
    }
    
  } // END void LayerForward ( double[] l1, double[] l2, 
    //			       double[][]conn, int n1, int n2 )
  
//********************************************************//

  /**
   * Initializes weights in a 2D array to random values
   * in -1.0 to 1.0.
   */
  
  private void Randomize_Weights ( double[][] w, int m, int n )
  {
    int i, j;
    
    for ( i = 0; i <= m; i++ )
    {
      for ( j = 0; j <= n; j++ )
      {
	w[i][j] = ( ( random.nextDouble() * 2.0 ) - 1.0 );
      }
    }
    
  } // END Randomize_Weights ( w, m, n )

//********************************************************//

  /**
   * Initializes weights in a 2D array to zeroes.
   */
  
  private void Zero_Weights ( double[][] w, int m, int n )
  {
    int i, j;
    
    for ( i = 0; i <= m; i++ )
    {
      for ( j = 0; j <= n; j++ )
      {
	w[i][j] = 0.0;
      }
    }
    
  } // END Zero_Weights ( w, m, n )

//********************************************************//
 
  /**
   * The squashing function.  Currently, it's a sigmoid. 
   */
  
  /** as in formula 4.12 and illustrated in figure 4.6 from text **/
  
  private double squash ( double x )
  {    
    /** this formula looks correct **/
    return ( 1.0 / ( 1.0 + Math.exp ( -x ) ) );
    
  } // END private double squash ( double x )
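
  // Note: the o * ( 1.0 - o ) factors in OutputError() and
  // HiddenError() are exactly this sigmoid's derivative:
  //   d/dx squash ( x ) = squash ( x ) * ( 1 - squash ( x ) )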

/******************************************************************/

  /**
   * Main method for testing this class from the command line.
   * Expects the standard weka.classifiers.Evaluation options
   * (e.g. -t train.arff) followed by the scheme options listed
   * above.
   */
  
  public static void main ( String[] argv )  
  {
    try 
    {
      System.out.println ( Evaluation.evaluateModel ( new BackPropagation(), 
						      argv ) );
    } 
    catch ( Exception e )
    {
      System.out.println ( e.getMessage() );
    }

  } // END public static void main ( String[] argv )
  
//********************************************************//

//********************************************************//

} // END public class BackPropagation extends DistributionClassifier 
  //	 implements OptionHandler
    
  
