/*
 * Adapted from IB1:
 *
 * This implements the IBk (k-nearest-neighbour) algorithm from the book.
 *
 * Keith A. Pray
 * November 2000
 */

/*
 *    IB1.java
 *    Copyright (C) 1999 Stuart Inglis,Len Trigg,Eibe Frank
 *
 *    This program is free software; you can redistribute it and/or modify
 *    it under the terms of the GNU General Public License as published by
 *    the Free Software Foundation; either version 2 of the License, or
 *    (at your option) any later version.
 *
 *    This program is distributed in the hope that it will be useful,
 *    but WITHOUT ANY WARRANTY; without even the implied warranty of
 *    MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
 *    GNU General Public License for more details.
 *
 *    You should have received a copy of the GNU General Public License
 *    along with this program; if not, write to the Free Software
 *    Foundation, Inc., 675 Mass Ave, Cambridge, MA 02139, USA.
 */

//********************************************************//

import weka.classifiers.Classifier;
import weka.classifiers.UpdateableClassifier;
import weka.classifiers.Evaluation;

import java.io.*;
import java.util.*;
import java.lang.Math;

import weka.core.*;

//********************************************************//

/**
 */

public class IBk extends Classifier
    implements UpdateableClassifier 
{

//********************************************************//
// ----- Data Members -----
//********************************************************//
  
  /** The training instances used for classification. */
  private Instances m_Train;

  /** The minimum values for numeric attributes. */
  private double [] m_MinArray;

  /** The maximum values for numeric attributes. */
  private double [] m_MaxArray;

  /** The number of neighbors to use for classification */
  private int k = 7;

//********************************************************//
// ----- Methods-----
//********************************************************//

  /**
   * Generates the classifier.
   *
   * @param instances
   *  set of instances serving as training data 
   *   
   * @exception 
   * Exception if the classifier has not been generated 
   * successfully
   */
  
  public void buildClassifier ( Instances instances ) 
    throws Exception
  {  
    if ( instances.checkForStringAttributes() ) 
    {
      throw new Exception ( "Can't handle string attributes!" );
    }

    // Throw away training instances with missing class
    m_Train = new Instances ( instances, 0, instances.numInstances() );
    m_Train.deleteWithMissingClass();

    m_MinArray = new double [ m_Train.numAttributes() ];
    m_MaxArray = new double [ m_Train.numAttributes() ];
    
    for ( int i = 0; i < m_Train.numAttributes(); i++ )
    {
      m_MinArray[i] = m_MaxArray[i] = Double.NaN;
    }

    Enumeration enum = m_Train.enumerateInstances();
    
    while ( enum.hasMoreElements() )
    {
      updateMinMax ( (Instance) enum.nextElement() );
    }

  } // END buildClassifier ( Instances instances ) 
    //	 throws Exception

//********************************************************//

  /**
   * Updates the classifier.
   *
   * @param instance
   * the instance to be put into the classifier
   *
   * @exception
   * Exception if the instance could not be included 
   * successfully
   */

  public void updateClassifier ( Instance instance )
    throws Exception
  {
  
    if ( m_Train.equalHeaders ( instance.dataset() ) == false )
    {
      throw new Exception ( "Incompatible instance types" );
    }

    if ( instance.classIsMissing() )
    {
      return;
    }

    m_Train.add ( instance );
    updateMinMax ( instance );

  } // END updateClassifier ( Instance instance )

//********************************************************//

  /**
   * Classifies the given test instance.
   *
   * @param instance 
   * the instance to be classified
   *
   * @return
   * the predicted class for the instance 
   *
   * @exception
   * Exception if the instance can't be classified
   */

  public double classifyInstance ( Instance instance ) 
    throws Exception
  {  
    // make sure there are enough training instances
    // without them, we can't do anything
    if ( m_Train.numInstances() < k ) 
    {
      throw new Exception ( "Not enough training instances!" );
    }

    double distance, minmaxDistance = Double.MAX_VALUE, classValue = 0;

    updateMinMax ( instance );

    // keep track of the last k classes
    double [] lastk = new double [ k ];

    // keep track of the distances of the last k classes
    double [] lastkDistance = new double [ k ];

    int tempIndex =0;
    
    // init the Distance
    for ( tempIndex = 0; tempIndex < k; tempIndex++ )
    {
      lastkDistance [ tempIndex ] = Double.MAX_VALUE;
    }	

    Enumeration enum = m_Train.enumerateInstances();

    while ( enum.hasMoreElements() ) 
    {
      Instance trainInstance = (Instance) enum.nextElement();

      if ( ! trainInstance.classIsMissing() ) 
      {
	distance = distance ( instance, trainInstance );

	// see if this is one of our least
	if ( distance < minmaxDistance ) 
	{
	  // get the index of the minmax
	  tempIndex = 0;

	  while ( tempIndex < k )
	  {
	    if ( minmaxDistance == lastkDistance [ tempIndex ] )
	    {
	      break;
	    }

	    tempIndex++;
	  }

	  // put the new distance value there
	  lastkDistance [ tempIndex ] = distance;

	  // add class to lastk array
	  lastk [ tempIndex ] = trainInstance.classValue();

	  // reset minmax
	  minmaxDistance = 0;

	  // get value of the new minmax distance
	  for ( tempIndex = 0; tempIndex < k; tempIndex++ )
	  {
	    minmaxDistance = Math.max ( minmaxDistance, 
					lastkDistance [ tempIndex ] );
	  } // end for

	} // end if this distance less than minmax

      } // end if class not missing

    } // end while

    // ok, now we have to use the last k to classify
    // average the values
    double avg = 0;

    for ( tempIndex = 0; tempIndex < k; tempIndex++ )
    {
      avg += lastk [ tempIndex ];
    }

    avg = avg / k;

    //    System.out.print ( Math.round ( avg ) + ", " );
    
    // round the average to the nearest int
    return ( Math.round ( avg ) );

  } // END classifyInstance ( Instance instance )

//********************************************************//

  /**
   * Returns a description of this classifier.
   *
   * @return
   * a description of this classifier as a string.
   */

  public String toString()
  {
    // string to return
    String s;
    
    s = "IBk classifier: k = " + k;

    return ( s );	

  } // END toString()

//********************************************************//

  /**
   * Calculates the distance between two instances
   *
   * @param first
   * the first instance
   *
   * @param second
   * the second instance
   *
   * @return
   * the distance between the two given instances
   */

  private double distance ( Instance first, Instance second )
  {
    double diff, distance = 0;

    for ( int i = 0; i < m_Train.numAttributes(); i++ ) 
    { 
      if ( i == m_Train.classIndex() ) 
      {
	continue;
      }

      if ( m_Train.attribute(i).isNominal() ) 
      {
	// If attribute is nominal
	if ( first.isMissing(i) || second.isMissing(i) ||
	     ( (int)first.value(i) != (int)second.value(i) ) ) 
	{
	  distance += 1;
	}

      }
      else 
      {
	// If attribute is numeric
	if ( first.isMissing(i) || second.isMissing(i) )
	{
	  if ( first.isMissing(i) && second.isMissing(i) )
	  {
	    diff = 1;
	  } 
	  else
	  {
	    if ( second.isMissing(i) ) 
	    {
	      diff = norm ( first.value(i), i );
	    } 
	    else
	    {
	      diff = norm ( second.value(i), i );
	    }

	    if ( diff < 0.5 ) 
	    {
	      diff = 1.0 - diff;
	    }
	  }
	} 
	else 
	{
	  diff = norm ( first.value(i), i ) - 
	    norm ( second.value(i), i );
	}

	distance += diff * diff;
      }
    }
    
    return ( distance );

  } // END distance ( Instance first, Instance second )

//********************************************************//
    
  /**
   * Normalizes a given value of a numeric attribute.
   *
   * @param x
   * the value to be normalized
   *
   * @param i
   * the attribute's index
   */

  private double norm ( double x, int i )
  {
    if ( Double.isNaN ( m_MinArray[i] )
	 || Utils.eq ( m_MaxArray[i], m_MinArray[i] ) )
    {
      return 0;
    } 
    else
    {
      return ( x - m_MinArray[i] ) / 
	( m_MaxArray[i] - m_MinArray[i] );
    }

  } // END norm ( double x, int i )

//********************************************************//

  /**
   * Updates the minimum and maximum values for all the 
   * attributes based on a new instance.
   *
   * @param instance
   * the new instance
   */

  private void updateMinMax ( Instance instance )
  {  
    for ( int j = 0; j < m_Train.numAttributes(); j++ )
    {
      if ( ( m_Train.attribute(j).isNumeric() ) && 
	   ( ! instance.isMissing(j) ) )
      {
	if ( Double.isNaN ( m_MinArray[j] ) ) 
	{
	  m_MinArray[j] = instance.value(j);
	  m_MaxArray[j] = instance.value(j);
	} 
	else
	{
	  if ( instance.value(j) < m_MinArray[j] )
	  {
	    m_MinArray[j] = instance.value(j);
	  } 
	  else
	  {
	    if ( instance.value(j) > m_MaxArray[j] )
	    {
	      m_MaxArray[j] = instance.value(j);
	    }
	  }
	}
      }
    }
  } // END updateMinMax ( Instance instance )

//********************************************************//

  /**
   * Main method for testing this class.
   *
   * @param argv 
   * should contain command line arguments for evaluation
   * (see Evaluation).
   */

  public static void main ( String [] argv )
  {
    try 
    {
      System.out.println ( Evaluation.evaluateModel ( new IBk(), 
						      argv ) );
    } 
    catch ( Exception e )
    {
      System.err.println ( e.getMessage() );
    }

  } // END main ( String [] argv )

//********************************************************//

} // END public class IBk extends Classifier 
  //	implements UpdateableClassifier




