/*
 * Adapted and commented by:
 *
 * Keith A. Pray
 * October 2000
 */

/*
 *    NaiveBayesSimple.java
 *    Copyright (C) 1999 Eibe Frank
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

//********************************************************//

import weka.classifiers.DistributionClassifier;
import weka.classifiers.Classifier;
import weka.classifiers.Evaluation;
import weka.filters.DiscretizeFilter;
import weka.filters.NormalizationFilter;
import weka.filters.Filter;

import java.io.*;
import java.util.*;
import weka.core.*;

//********************************************************//

/**
 * Class for building and using a simple Naive Bayes 
 * classifier.
 * Numeric attributes are modelled by a normal distribution. 
 * For more information, see<p>
 *
 * Richard Duda and Peter Hart (1973).<i>Pattern
 * Classification and Scene Analysis</i>. Wiley, New York.
 *
 * @author Eibe Frank (eibe@cs.waikato.ac.nz)
 * @version $Revision: 1.6 $ 
*/

public class NaiveBayesSimple extends DistributionClassifier
{

//********************************************************//
// ----- Data Members -----
//********************************************************//

  /**
   * Per-class attribute statistics, indexed [class][attribute][value].
   * The attribute index counts the non-class attributes in enumeration
   * order. For a nominal attribute the innermost array has one slot per
   * value and, once the classifier is built, holds the add-one smoothed
   * value frequencies. For any other attribute it has a single slot
   * holding the number of non-missing training values.
   */
  private double [][][] m_Counts;
  
  /** The means for numeric attributes, indexed [class][attribute]. */
  private double [][] m_Means;

  /**
   * The standard deviations for numeric attributes,
   * indexed [class][attribute].
   */
  private double [][] m_Devs;

  /** The (smoothed) prior probabilities of the classes. */
  private double [] m_Priors;

  /**
   * Dataset object created from the training data with capacity 0.
   * toString() reads class and attribute names from it and treats
   * a null value as "no model built yet".
   */
  private Instances m_Instances;

  /**
   * Constant for normal distribution: sqrt(2 * pi).
   * Declared final because it is never reassigned.
   */
  private static final double NORM_CONST = Math.sqrt ( 2 * 
						 Math.PI );

//********************************************************//
// ----- Methods-----
//********************************************************//

  /**
   * Generates the classifier.
   *
   * @param instances 
   * set of instances serving as training data 
   *
   * @exception
   * Exception if the classifier has not been generated
   * successfully
   */

  public void buildClassifier ( Instances instances ) 
    throws Exception 
  {
    // index for the attributes
    int attIndex = 0;

    // the sum of course... but sum of what?
    double sum;
    
    // we won't handle string attributes
    if ( instances.checkForStringAttributes() ) 
    {
      throw new Exception ( "Can't handle string attributes!" );
    }
    
    // won't handle numeric class attributes
    if ( instances.classAttribute().isNumeric() ) 
    {
      throw new Exception ( "Naive Bayes: Class is numeric!" );
    }
    
    // make the instances with which we will build
    // out simple Bayes classifier
    m_Instances = new Instances ( instances, 0 );

    // ***** Reserve space for... ***** //

    // the counts of attribute values
    m_Counts = new double [ instances.numClasses() ]
      [ instances.numAttributes() - 1 ][ 0 ];

    // need a mean for each attribute
    m_Means = new double [ instances.numClasses() ]
      [ instances.numAttributes() - 1 ];

    // need standard deviation too
    m_Devs = new double [ instances.numClasses() ]
      [ instances.numAttributes() - 1 ];

    // have to keep track of previous probabilities
    m_Priors = new double [ instances.numClasses() ];

    // get an enumeration of the attributes so we can
    // traverse them, oh yeah!
    Enumeration enum = instances.enumerateAttributes();

    // traverse away...
    while ( enum.hasMoreElements() ) 
    {
      // get the next attribute
      Attribute attribute = (Attribute) enum.nextElement();

      // if it is nominal (discrete), make as many places as
      // values it can have
      if ( attribute.isNominal() ) 
      {
	for (int j = 0; j < instances.numClasses(); j++) 
	{
	  m_Counts [ j ][ attIndex ] = new double [ attribute.numValues() ];
	}
      } 
      else 
      {
	// otherwise we just make 1 space for each
	for ( int j = 0; j < instances.numClasses(); j++ ) 
	{
	  m_Counts [ j ][ attIndex ] = new double[1];
	}

      } // end while there are more attributes

      // keep track of which attribute we are on
      attIndex++;
    }
    
    // ***** Compute counts and sums ***** //

    // need an enumeration of the instances
    Enumeration enumInsts = instances.enumerateInstances();

    // traverse baby!
    while ( enumInsts.hasMoreElements() )
    {
      // get the next instance
      Instance instance = (Instance) enumInsts.nextElement();
      
      // if this instance has a value for the class 
      // attribute, use it
      if ( ! instance.classIsMissing() ) 
      {
	// get the attributes for this instance
	Enumeration enumAtts = instances.enumerateAttributes();
	attIndex = 0;

	// for each attribute
	while ( enumAtts.hasMoreElements() ) 
	{
	  // get the next attribute
	  Attribute attribute = (Attribute) enumAtts.nextElement();

	  // if the attribute has a value
	  if ( ! instance.isMissing ( attribute ) ) 
	  {
	    // if the attribute is nominal (discrete)
	    // add 1 to the count for the value
	    if ( attribute.isNominal() ) 
	    {
	      m_Counts [ (int)instance.classValue() ][ attIndex ]
		[ (int)instance.value ( attribute ) ]++;

	    } // end if attribute is nominal (discrete)
	    else 
	    {
	      // otherwise
	      m_Means [ (int)instance.classValue() ][ attIndex ] +=
		instance.value ( attribute );

	      m_Counts [ (int)instance.classValue() ][ attIndex ][0]++;

	    } // end else attribute is numeric

	  } // end if class attribute value present
	  
	  // keep track of which attribute we are on
	  attIndex++;

	} // end while there are more attributes

	m_Priors [ (int)instance.classValue() ]++;

      } // end if class value not missing

    } // end while there are more instances
    
    // ***** Compute means ***** //
    
    // get the attributes for these instances
    Enumeration enumAtts = instances.enumerateAttributes();
    attIndex = 0;

    // for each attribute, compute the mean
    while ( enumAtts.hasMoreElements() ) 
    {
      // get the next attribute
      Attribute attribute = (Attribute) enumAtts.nextElement();

      // only if the attribute is numeric of course
      if ( attribute.isNumeric() ) 
      {
	for ( int j = 0; j < instances.numClasses(); j++ ) 
	{
	  if ( m_Counts [ j ][ attIndex ][ 0 ] < 2 ) 
	  {
	    throw new Exception ( "attribute " + attribute.name() +
				  ": less than two values for class " +
				  instances.classAttribute().value ( j ) );

	  } // end if there are less than 2

	  m_Means [ j ][ attIndex ] /= m_Counts [ j ][ attIndex ][ 0 ];

	} // end for each class value

      } // end if attribute is numeric

      // keep track of which attribute we are on
      attIndex++;

    } // end while there are more attributes
    
    // ***** Compute standard deviations ***** //
    enumInsts = instances.enumerateInstances();

    // for each instance (training example)
    // we first compute the differences
    while ( enumInsts.hasMoreElements() ) 
    {
      // get the next instance
      Instance instance = 
	(Instance) enumInsts.nextElement();

      // use if the class attribute value is present
      if ( ! instance.classIsMissing() ) 
      {
	// get the attributes of this instance
	enumAtts = instances.enumerateAttributes();
	attIndex = 0;

	// for each attribute
	while ( enumAtts.hasMoreElements() ) 
	{
	  Attribute attribute = (Attribute) enumAtts.nextElement();

	  // only if the attribute value is present
	  if ( ! instance.isMissing ( attribute ) ) 
	  {
	    // only if the attribute is numeric of course
	    if ( attribute.isNumeric() ) 
	    {
	      m_Devs [ (int)instance.classValue() ][ attIndex ] +=
		( m_Means [ (int)instance.classValue() ][ attIndex ]-
		 instance.value ( attribute ) ) *
		( m_Means [ (int)instance.classValue() ][ attIndex ]-
		 instance.value ( attribute ) );

	    } // end if attribute is numeric

	  } // end if attribute value not missing

	  // keep track of which attribute we are on
	  attIndex++;

	} // end while there are more attributes

      } // end is class attribute value is present
    }

    enumAtts = instances.enumerateAttributes();
    attIndex = 0;

    // now we calculate the average of the differences
    while ( enumAtts.hasMoreElements() ) 
    {
      Attribute attribute = (Attribute) enumAtts.nextElement();

      // do it for only numeric attributes of course
      if ( attribute.isNumeric() ) 
      {
	// for each class
	for ( int j = 0; j < instances.numClasses(); j++ ) 
	{
	  if ( m_Devs [ j ][ attIndex ] <= 0 ) 
	  {
	    throw new Exception ( "attribute " + attribute.name() +
				  ": standard deviation is 0 for class " +
				  instances.classAttribute().value ( j ) );

	  } // end if standard deviation is 0 (weird when that happens)
	  else 
	  {
	    // calculate the std. dev.
	    m_Devs [ j ][ attIndex ] /= m_Counts [ j ][ attIndex ][ 0 ] - 1;
	    m_Devs [ j ][ attIndex ] = Math.sqrt ( m_Devs [ j ][ attIndex ] );

	  } // end else std. dev. not 0

	} // end for each class

      } // if attribute is numeric

      // keep track of while attribute we are on
      attIndex++;

    } // end while there are more attributes
    
    // ***** Normalize counts ***** //
    enumAtts = instances.enumerateAttributes();
    attIndex = 0;

    while ( enumAtts.hasMoreElements() ) 
    {
      Attribute attribute = (Attribute) enumAtts.nextElement();

      // we want to work with the sum of counts = 1
      if ( attribute.isNominal() ) 
      {
	for ( int j = 0; j < instances.numClasses(); j++ ) 
	{
	  sum = Utils.sum ( m_Counts [ j ][ attIndex ] );

	  for ( int i = 0; i < attribute.numValues(); i++ ) 
	  {
	    m_Counts [ j ][ attIndex ][ i ] =
	      (m_Counts [ j ][ attIndex ][ i ] + 1 ) 
	      / ( sum + (double)attribute.numValues() );

	  } // end for each value for this attribute

	} // end for each class

      } // end if attribute is nominal (discrete)

      attIndex++;

    } // end while there are more attributes
    
    // ***** Normalize priors ***** //
    sum = Utils.sum ( m_Priors );

    for ( int j = 0; j < instances.numClasses(); j++ )
    {
      m_Priors [ j ] = ( m_Priors [ j ] + 1 ) 
	/ ( sum + (double)instances.numClasses() );
    }

    // All done :)

  } // END public void buildClassifier ( Instances instances ) 
    //    throws Exception

//********************************************************//

  /**
   * Calculates the class membership probabilities for
   * the given test instance.
   * <p>
   * For every class the result starts at 1 and is multiplied, in
   * attribute order, by the stored frequency of each present nominal
   * value and by the normal density of each present numeric value.
   * Missing attribute values are skipped. The product is then
   * multiplied by the class prior and the array is normalized.
   *
   * @param instance 
   * the instance to be classified
   *
   * @return
   * predicted class probability distribution
   *
   * @exception
   * Exception if distribution can't be computed
   */

  public double[] distributionForInstance ( Instance instance ) 
    throws Exception 
  {
    int numClasses = instance.numClasses();
    double [] probs = new double [ numClasses ];

    for ( int cls = 0; cls < numClasses; cls++ ) 
    {
      // running product of the per-attribute factors for this class
      double likelihood = 1;
      int attIndex = 0;

      // attIndex advances with every attribute, present or missing
      for ( Enumeration atts = instance.enumerateAttributes();
            atts.hasMoreElements(); attIndex++ ) 
      {
        Attribute att = (Attribute) atts.nextElement();

        // a missing value contributes no factor
        if ( instance.isMissing ( att ) ) 
        {
          continue;
        }

        if ( att.isNominal() ) 
        {
          // stored frequency of the observed value
          int valueIndex = (int) instance.value ( att );

          likelihood *= m_Counts [ cls ][ attIndex ][ valueIndex ];
        } 
        else 
        {
          // normal density using the stored mean and std. dev.
          double mean = m_Means [ cls ][ attIndex ];
          double stdDev = m_Devs [ cls ][ attIndex ];

          likelihood *= normalDens ( instance.value ( att ), 
                                     mean, stdDev );
        }
      }

      // factor in the prior probability last
      probs [ cls ] = likelihood * m_Priors [ cls ];
    }

    // scale so the entries sum to 1
    Utils.normalize ( probs );

    return probs;

  } // END public double[] distributionForInstance ( Instance instance ) 
    //    throws Exception

//********************************************************//

  /**
   * Returns a description of the classifier.
   * <p>
   * Lists, for each class, its prior followed by every attribute:
   * the stored value frequencies for nominal attributes, the mean and
   * standard deviation otherwise.
   *
   * @return
   * a description of the classifier as a string; a fixed
   * placeholder if no model exists yet or printing fails.
   */

  public String toString() 
  {
    // nothing to describe before a classifier exists
    if ( m_Instances == null ) 
    {
      return "Naive Bayes (simple): No model built yet.";
    }

    StringBuffer text = new StringBuffer ( "Naive Bayes (simple)" );

    try 
    {
      for ( int cls = 0; cls < m_Instances.numClasses(); cls++ ) 
      {
        // class header with its prior
        text.append ( "\n\nClass " )
            .append ( m_Instances.classAttribute().value ( cls ) )
            .append ( ": P(C) = " )
            .append ( Utils.doubleToString ( m_Priors [ cls ], 10, 8 ) )
            .append ( "\n\n" );

        int attIndex = 0;

        for ( Enumeration atts = m_Instances.enumerateAttributes();
              atts.hasMoreElements(); attIndex++ ) 
        {
          Attribute att = (Attribute) atts.nextElement();

          text.append ( "Attribute " )
              .append ( att.name() )
              .append ( "\n" );

          if ( att.isNominal() ) 
          {
            // one "value <tab> frequency" line per nominal value
            for ( int v = 0; v < att.numValues(); v++ ) 
            {
              String freq = Utils.doubleToString 
                ( m_Counts [ cls ][ attIndex ][ v ], 10, 8 );

              text.append ( att.value ( v ) )
                  .append ( "\t" )
                  .append ( freq )
                  .append ( "\n" );
            }

            text.append ( "\n" );
          } 
          else 
          {
            String mean = Utils.doubleToString 
              ( m_Means [ cls ][ attIndex ], 10, 8 );
            String stdDev = Utils.doubleToString 
              ( m_Devs [ cls ][ attIndex ], 10, 8 );

            text.append ( "Mean: " )
                .append ( mean )
                .append ( "\t" )
                .append ( "Standard Deviation: " )
                .append ( stdDev );
          }

          text.append ( "\n\n" );
        }
      }

      return text.toString();
    } 
    catch ( Exception e ) 
    {
      // best effort: describing the model must never fail the caller
      return "Can't print Naive Bayes classifier!";
    }

  } // END public String toString()

//********************************************************//

  /**
   * Density function of normal distribution.
   *
   * @param x the point at which the density is evaluated
   * @param mean the mean of the distribution
   * @param stdDev the standard deviation of the distribution
   *
   * @return the density at x
   */

  private double normalDens ( double x, double mean, 
			      double stdDev ) 
  {
    // distance of x from the centre of the distribution
    double diff = x - mean;

    // 1 / ( sqrt(2 pi) * sigma )
    double scale = 1 / ( NORM_CONST * stdDev );

    // -( (x - mean)^2 / (2 sigma^2) )
    double exponent = - ( diff * diff / ( 2 * stdDev * stdDev ) );

    return scale * Math.exp ( exponent );

  } // END private double normalDens ( double x, double mean, 
    //			      double stdDev )

//********************************************************//

  /**
   * Main method for testing this class. Hands a new NaiveBayesSimple
   * and the command-line options to Evaluation.evaluateModel and
   * prints the result. If that fails, the exception message is
   * printed to standard error.
   *
   * @param argv the options
   */

  public static void main ( String [] argv ) 
  {
    try 
    {
      // typed as Classifier, exactly what evaluateModel is handed
      Classifier scheme = new NaiveBayesSimple();
      String report = Evaluation.evaluateModel ( scheme, argv );

      System.out.println ( report );
    } 
    catch ( Exception e ) 
    {
      System.err.println ( e.getMessage() );
    }

  } // END public static void main ( String [] argv )

//********************************************************//

} // END public class NaiveBayesSimple 
  //			extends DistributionClassifier

//********************************************************//
