// --*- C++ -*------x---------------------------------------------------------
// $Id: KnnNetTrainer.h,v 1.1.1.1 2006/07/03 14:43:20 bindewae Exp $
//
// Class:           KnnNetTrainer
// 
// Base class:      -
//
// Derived classes: - 
//
// Author:          Eckart Bindewald
//
// Description:     labeled training data handling, accuracy estimation and
//                  optimization for K-nearest-neighbor (KNN) classifiers
// 
// Reviewed by:     -
// -----------------x-------------------x-------------------x-----------------

#ifndef __KNN_NET_TRAINER_H__
#define __KNN_NET_TRAINER_H__

// Includes

#include <iostream>
#include <ClassifierBase.h>
#include <debug.h>
#include <Vec.h>

/** Container for labeled training data used to evaluate and optimize
    K-nearest-neighbor classifiers.

    Stores a set of data rows (one Vec<double> per row), the class label
    of each row (parallel to the rows) and the total number of classes.
    Given a ClassifierBase it can estimate that classifier's accuracy and
    optimize it (see estimateAccuracy() and optimize()).

    @author Eckart Bindewald
    @see    Modern Applied Statistics with S / Venables & Ripley
    @review - */
class KnnNetTrainer {

public:
  KnnNetTrainer();

  KnnNetTrainer(const KnnNetTrainer& orig);

  virtual ~KnnNetTrainer();

  /* OPERATORS */

  /** Assignment operator. */
  KnnNetTrainer& operator = (const KnnNetTrainer& orig);

  /** Stream output operator (format is defined in the implementation file). */
  friend std::ostream& operator << (std::ostream& os, const KnnNetTrainer& rval);

  /** Stream input operator (format is defined in the implementation file). */
  friend std::istream& operator >> (std::istream& is, KnnNetTrainer& rval);


  /* PREDICATES */

  /** Is current state valid? True if at least one data row is stored. */
  virtual bool isValid() const { return (data.size() > 0); }
  
  /** Number of stored data rows. */
  virtual unsigned int size() const { return data.size(); }

  /** Dimension of each data vector, taken from the first row;
      returns 0 if no data rows are stored. */
  virtual unsigned int getDim() const {
    if (data.size() == 0) {
      return 0;
    }
    return data[0].size();
  }

  /** Number of classes, as set by setData() or setNumClasses(). */
  virtual unsigned int getNumClasses() const { return numClasses; }

  /** All stored data rows. */
  virtual const Vec<Vec<double> >& getData() const { return data; }

  /** Returns (copies of) the data rows which belong to class dataClass. */
  virtual Vec<Vec<double> > getData(unsigned int dataClass) const;

  /** Returns the indices of the data rows which belong to class dataClass. */
  virtual Vec<unsigned int> getDataIndices(unsigned int dataClass) const;

  /** Returns data row n. Precondition: n < size(). */
  virtual const Vec<double>& getDataRow(unsigned int n) const {
    PRECOND(n < data.size());
    return data[n];
  }

  /** Returns the class label of data row n. Precondition: n < size(). */
  virtual unsigned int getDataRowClass(unsigned int n) const {
    PRECOND(n < dataClasses.size());
    return dataClasses[n];
  }

  /** Returns the class labels of all data rows; entry i is the class of
      getDataRow(i). */
  virtual const Vec<unsigned int>& getDataClasses() const { return dataClasses; }

  /** Writes training vectors derived from knnNet to writeFile.
      NOTE(review): numEntries presumably limits how many vectors are
      written, and "level" presumably refers to the next stage of a
      layered classifier. Confirm both against the implementation. */
  virtual void writeLevelTrainVectors(std::ostream& writeFile, int numEntries,
				      const ClassifierBase& knnNet) const;

  /** Estimates the classification accuracy of knnNet.
      NOTE(review): presumably evaluated on the stored data rows; confirm
      the estimation scheme against the implementation. */
  virtual double estimateAccuracy(const ClassifierBase& knnNet) const;

  /* MODIFIERS */

  /** Reads labeled input data (rows and their classes) from is. */
  virtual void readData(std::istream& is);

  /** Reads query data from is.
      NOTE(review): how this differs from readData() (e.g. unlabeled rows)
      is defined in the implementation; confirm there. */
  virtual void readQueryData(std::istream& is);

  /** Replaces the stored data.
      @param mtx      data rows
      @param dClasses class label of each row (must have the same length as mtx)
      @param _nClass  total number of classes */
  virtual void setData(const Vec<Vec<double> >& mtx, const Vec<unsigned int>& dClasses,
		       unsigned int _nClass) { 
    PRECOND(mtx.size() == dClasses.size());
    data = mtx;
    dataClasses = dClasses;
    numClasses = _nClass;
  }
  
  /** Sets the number of classes. */
  virtual void setNumClasses(unsigned int n) { numClasses = n; }

  /** Sets the verbosity level stored in this object. */
  virtual void setVerboseLevel(int n) { verboseLevel = n; }

  /** Optimizes knnNet in place (it is passed by non-const reference).
      NOTE(review): numSteps / numNodeSteps presumably bound the overall
      and per-node iteration counts. verboseLevel is passed explicitly
      and is separate from the member set by setVerboseLevel(). Confirm
      against the implementation. */
  void optimize(ClassifierBase& knnNet, unsigned int numSteps,
		unsigned int numNodeSteps,
		int verboseLevel);
  
protected:
  /* OPERATORS  */
  /* PREDICATES */
  /* MODIFIERS  */

  /** Copies all attributes of other into this object. */
  void copy(const KnnNetTrainer& other);

private:
  /* OPERATORS  */
  /* PREDICATES */

  /* MODIFIERS  */

private:
  
  /* PRIVATE ATTRIBUTES */

  /** Total number of classes. */
  unsigned int numClasses;

  /** Verbosity level (see setVerboseLevel()). */
  int verboseLevel;

  /** Data rows, one vector per row. */
  Vec<Vec<double> > data;

  /** Class label of each data row (parallel to data). */
  Vec<unsigned int> dataClasses;

};

#endif /* __KNN_NET_TRAINER_H__ */

