#include <string>
#include <Physical.h>
#include <CorrectedMutualInformationScorer.h>
#include <utility>  // used for pair class
#include <generalNumerics.h>
#include <Random.h>
#include <aligncolor_help.h>
#include <StringTools.h>

// #define DEBUG_VERBOSE

/** Natural logarithm of 2; approximateCorrection() divides by it so that its result is expressed in bits. */
const double LOG_2 = log(2.0);

/** Mutual information of two alignment columns given as strings.
 *
 * Rows in which at least one of the two columns carries a gap are dropped
 * first (see cleanAliColumns). The value returned is Rij - Ri - Rj, where
 * Rij is the pair information of the cleaned columns and Ri, Rj are their
 * single-column informations. The three terms are combined with Physical's
 * own subtraction, which is relied on for the error propagation.
 *
 * @param s1Orig first alignment column, one character per sequence
 * @param s2Orig second alignment column, one character per sequence
 * @return information value together with its error estimate
 */
Physical
CorrectedMutualInformationScorer::computeMutualInformation(const string& s1Orig, const string& s2Orig) const 
{
  // keep only the rows that are gap-free in both columns
  pair<string, string> cleaned = cleanAliColumns(s1Orig, s2Orig);
  AlignmentColumn firstColumn(cleaned.first, alphabet);
  AlignmentColumn secondColumn(cleaned.second, alphabet);

  // the three terms are evaluated in this fixed order: pair, first, second
  Physical pairInformation = computeInformationClean(firstColumn, secondColumn);
  Physical firstInformation = computeInformation(firstColumn);
  Physical secondInformation = computeInformation(secondColumn);

  // Rij - Ri - Rj
  Physical mutualInformation = pairInformation - firstInformation 
    - secondInformation;
#ifdef DEBUG_VERBOSE
  cout << "Debug output; Rij, Ri, Rj: " << pairInformation << " " 
       << firstInformation << " " << secondInformation << " result: "
       << mutualInformation << endl;
#endif
  return mutualInformation;
}

/** Information of a single alignment column given as a string.
 *
 * The string is wrapped into an AlignmentColumn over this scorer's alphabet
 * and handed to the AlignmentColumn overload. Gaps are not stripped here.
 *
 * @param s alignment column, one character per sequence
 * @return information value together with its error estimate
 */
Physical
CorrectedMutualInformationScorer::computeInformation(const string& s) const {
  AlignmentColumn column(s, alphabet);
  Physical information = computeInformation(column);
  return information;
}

/** Approximate small-sample correction term according to Basharin.
 *
 * Evaluates (s - 1) / (2 * ln(2) * n).
 *
 * @param s alphabet size
 * @param n number of characters (samples)
 * @return the correction term; n == 0 leads to a floating point division by zero
 */
double
CorrectedMutualInformationScorer::approximateCorrection(unsigned int s, unsigned int n) {
  const double degreesOfFreedom = s - 1.0;
  const double scale = 2.0 * LOG_2 * static_cast<double>(n);
  return degreesOfFreedom / scale;
}

/** Uncertainty before any observation: lg2 of the number of possible symbols.
 *
 * @param n number of distinct symbols
 * @return lg2(n)
 */
double
CorrectedMutualInformationScorer::initialUncertainty(unsigned int n)
{
  const double symbolCount = static_cast<double>(n);
  return lg2(symbolCount);
}

/** Relative frequency of the n'th alphabet character in an alignment column.
 *
 * The count of that character is divided by the number of non-gap characters,
 * so gaps do not contribute to the denominator.
 *
 * @param col alignment column
 * @param n index of the character within the alphabet
 * @return count / non-gap count, or 0.0 when the column has no non-gap characters
 */
double 
CorrectedMutualInformationScorer::frequency2(const AlignmentColumn& col, 
					     unsigned int n) const {
  unsigned int occurrences = col.getCount(n);
  unsigned int nonGapTotal = col.getNonGapCount();
  return (nonGapTotal == 0)
    ? 0.0
    : static_cast<double>(occurrences) / static_cast<double>(nonGapTotal);
}

/** Relative frequency of the character pair (c1, c2) in two alignment columns,
 * not counting rows with gaps.
 *
 * A row is skipped when either column holds a gap there; the remaining rows
 * form the denominator. The scan runs over col1.size() rows, so col2 must
 * provide at least that many rows.
 *
 * @param col1 first alignment column
 * @param c1 character looked for in col1
 * @param col2 second alignment column
 * @param c2 character looked for in col2
 * @return matching rows / gap-free rows, or 0.0 when no row is gap-free
 */
double 
CorrectedMutualInformationScorer::frequency2(const AlignmentColumn& col1, char c1,
			       const AlignmentColumn& col2, char c2) const {
  const unsigned int rowCount = col1.size();
  unsigned int matches = 0;
  unsigned int usableRows = 0; // rows without a gap in either column
  for (unsigned int row = 0; row < rowCount; ++row) {
    if (alphabet.isGap(col1.getChar(row)) || alphabet.isGap(col2.getChar(row))) {
      continue;
    }
    ++usableRows;
    if ((col1.getChar(row) == c1) && (col2.getChar(row) == c2)) {
      ++matches;
    }
  }
  if (usableRows == 0) {
    return 0.0;
  }
  return static_cast<double>(matches) / static_cast<double>(usableRows);
}

/** Relative frequency of the character pair (c1, c2) in two alignment columns,
 * with every row (including rows with gaps) counted in the denominator.
 *
 * The scan runs over col1.size() rows, so col2 must provide at least that
 * many rows.
 *
 * @param col1 first alignment column
 * @param c1 character looked for in col1
 * @param col2 second alignment column
 * @param c2 character looked for in col2
 * @return matching rows / col1.size(), or 0.0 for an empty col1
 */
double 
CorrectedMutualInformationScorer::frequency(const AlignmentColumn& col1, char c1,
			      const AlignmentColumn& col2, char c2) const {
  const unsigned int rowCount = col1.size();
  if (rowCount == 0) {
    return 0.0;
  }
  unsigned int matches = 0;
  for (unsigned int row = 0; row < rowCount; ++row) {
    if (col1.getChar(row) == c1) {
      // col2 is only inspected when col1 already matched
      if (col2.getChar(row) == c2) {
	++matches;
      }
    }
  }
  return static_cast<double>(matches) / static_cast<double>(rowCount);
}

/** Uncorrected uncertainty of one alignment column with respect to the alphabet.
 *
 * Computes - sum_i f_i * lg2(f_i) over all alphabet characters, where f_i is
 * the gap-excluding frequency from frequency2(col, i). Characters with zero
 * frequency contribute nothing.
 *
 * @param col alignment column
 * @return uncertainty, with no small-sample correction applied
 */
double
CorrectedMutualInformationScorer::uncorrectedUncertainty(const AlignmentColumn& col) const
{
  double uncertainty = 0.0;
  const unsigned int symbolCount = alphabet.size();
  for (unsigned int symbol = 0; symbol < symbolCount; ++symbol) {
    const double p = frequency2(col, symbol);
    if (!(p > 0.0)) {
      continue; // skip characters that do not occur
    }
    uncertainty -= p * lg2(p);
  }
  return uncertainty;
}

/** used for computation of variance of uncertainty.
 * According to Basharin 1959 
 */
double
CorrectedMutualInformationScorer::uncorrectedUncertaintyVariance(const AlignmentColumn& col) const
{
  double result = 0.0;
  for (unsigned int i = 0; i < alphabet.size(); ++i) {
    double fi = frequency2(col, i);
    if (fi > 0.0) {
      double lgf = lg2(fi);
      result += fi * lgf * lgf;
    }
  }
  return result;
}

/** Uncorrected joint uncertainty of two alignment columns with respect to the alphabet.
 *
 * Computes - sum_ij f_ij * lg2(f_ij) over all ordered pairs of alphabet
 * characters, where f_ij is the fraction of rows holding character i in col1
 * and character j in col2 (see frequency()).
 * Assumes that the columns are already cleaned from gaps!
 *
 * @param col1 first alignment column
 * @param col2 second alignment column
 * @return joint uncertainty, with no small-sample correction applied
 */
double
CorrectedMutualInformationScorer::uncorrectedUncertainty(
							 const AlignmentColumn& col1,
							 const AlignmentColumn& col2) const
{
  double uncertainty = 0.0;
  for (unsigned int first = 0; first < alphabet.size(); ++first) {
    const char firstChar = alphabet.getChar(first);
    for (unsigned int second = 0; second < alphabet.size(); ++second) {
      // fraction of rows in which both characters occur together
      const double pairFreq = frequency(col1, firstChar, col2, alphabet.getChar(second));
      if (!(pairFreq > 0.0)) {
	continue;
      }
      uncertainty -= pairFreq * lg2(pairFreq);
    }
  }
  return uncertainty;
}

/** Computes uncorrected uncertainty of an alignment column with respect to an alphabet.
 * Assumes that columns are already cleaned from gaps! 
 */
double
CorrectedMutualInformationScorer::uncorrectedUncertaintyVariance(
				 const AlignmentColumn& col1,
				 const AlignmentColumn& col2) const
{
  double result = 0.0;
  double lgf = 0.0; // temporary variable
  for (unsigned int i = 0; i < alphabet.size(); ++i) {
    for (unsigned int j = 0; j < alphabet.size(); ++j) {
      double fij = frequency(col1, alphabet.getChar(i), col2, alphabet.getChar(j)); // how often characters i and j in same row?
      if (fij > 0.0) {
	lgf = lg2(fij);
	result += fij * lgf * lgf; 
      }
    }
  }
  return result;
}

/* begin module calehnb */
void 
CorrectedMutualInformationScorer::calehnb(long n, long gna, long gnc, long gng, long gnt, 
			    double* hg, double* ehnb, double* varhnb) const
// long n, gna, gnc, gng, gnt;
// double *hg, *ehnb, *varhnb;
{
  const int maxsize = 200; // maximum size for logarithm
  const long accuracy = 10000;
  /* ; debugging: boolean */
  /* calculate e(hnb) in bits/bp (ehnb) for a number (n) of example sequence
sites.  gna to gnt are the composition to use for the genome probabilities
of a to t.  the genomic uncertainty hg and the variance var(hnb)
(=varhnb) are also calculated. if the variable debugging is passed to the procedure then the individual
combinations of hnb are displayed.

        note: this procedure should not be broken into smaller
	procedures so that it remains efficient.
        version = 3.02; of procedure calehnb 1983 nov 23 */
  /* less than (1/accuracy) bits error is demanded
  for the sum of pnb (see variable 'total') at the end of the procedure */

  double log2 = log(2.0);   /* natural log of 2, used to find log base 2 */
  double logn;   /* log of n */
  double nlog2;   /* n * log2 */

  long gn;   /* sum of gna..gnt */
  double logpa, logpc, logpg, logpt;   /* logs of genome probabilities */

  /* log of n factorial is the sum of i=1 to n of log(i).
  the array below represents these logs up to n */
  double logfact[maxsize + 1];

  /* precalculated values of -p*log2(p), where p=nb/n for
     nb = 0 .. n.  m stands for minus */
  double mplog2p[maxsize + 1];

  long i;   /* index for logfact and mplog2p */
  double logi;   /* natural log of i */

  long na;
  long nc = 0, ng = 0, nt = 0;   /* numbers of bases in a site */
  bool done = false;   /* true when the loop is completed */

  double pnb;
  /* multinomial probability of a combination
                of na, nc, ng, nt */
  double hnb;   /* uncertainty for a combination of na..nt */
  double pnbhnb;   /* pnb*hnb, an intermediate result */
  double sshnb = 0.0;   /* sum of squares of hnb */

  /* variables for testing program correctness: */
  double total = 0.0;
  /* sum of pnb over all combinations of na..nt
     if this is not 1.00, the program is in error */
  long counter = 0;

  /* counts the number of times through
     the loop */

  /* prevent access to outside the arrays: */
  if (n > maxsize) {
    ERROR("n larger maximum size!", exception);
  }

  logn = log((double)n);
  nlog2 = n * log2;

  /* get logs of genome probabilities */
  gn = gna + gnc + gng + gnt;
  logpa = log((double)gna / gn);
  logpc = log((double)gnc / gn);
  logpg = log((double)gng / gn);
  logpt = log((double)gnt / gn);

  /* find genomic uncertainty */
  *hg = -((gna * logpa + gnc * logpc + gng * logpg + gnt * logpt) / (gn * log2));

  *ehnb = 0.0;   /* start error uncertainty at zero */

  /* make table of log of n factorial up to n
     and entropies for nb/n */
  logfact[0] = 0.0;   /* factorial(0) = 0 */
  mplog2p[0] = 0.0;
  for (i = 1; i <= n; i++) {
    logi = log((double)i);
    logfact[i] = logfact[i-1] + logi;
    mplog2p[i] = i * (logn - logi) / nlog2;
  }

  /* begin by looking at the combination with all a: na = n */
  na = n;

  /* the following loop simulates a number of nested loops
  of the form:
     for b1=a to t do
        for b2=b1 to t do
           for b3=b2 to t do
              ...
                 for bn=b(n-1) to t do ...
  the resulting set of variables increase in alphabetic order
  since no inner loop variable can have a value less than any
  outer loop.  the number of times through the inner-most loop
  is given by:
     o = (n + 1)*(n + 2)*(n + 3)/6
  in the case where there are four symbols (a,c,g,t) and n is
  the number of nested loops.
     a recursive set of loops would be possible, but it
  would use up too much memory in practical cases (up to n=150
  or higher).  a second algorithm sequests the loop variables
  into an array and increments them there.  however, the goal
  is to get all possible combinations for na, nc, ng, nt, where
  the sum of these is n.  the nested loops provide all the
  combinations in alphabetic order, assuring that there can not
  be any duplicates.  to find nb (one of na..nt) one would look
  at which of the variables b1 to bn were of value b.  this is
  a wasteful operation.
     the loop below simulates the array of control variables
  by changing each nb directly.
  */

  do {
    /* pnb is calculated by taking the log of the expression

               fact(n)          na     nc     ng     nt
    pnb = ------------------- pa   * pc   * pg   * pt  .
          fact(na).. fact(nt)

    log(pnb) generates a series of sums, allowing
    the calculation to proceed by addition and
    multiplication rather than multiplication and
    exponentiation.  the factorials become tractable
    in this way */

    pnb = exp(logfact[n] - logfact[na] - logfact[nc] - logfact[ng] -
	      logfact[nt] + na * logpa + nc * logpc + ng * logpg + nt * logpt);
	/* n factorial */

    hnb = mplog2p[na] + mplog2p[nc] + mplog2p[ng] + mplog2p[nt];

    pnbhnb = pnb * hnb;

    *ehnb += pnbhnb;

    sshnb += pnbhnb * hnb;   /* sum of squares of hnb */

    /* the following section keeps track of the calculation
    and writes out the current set of nb. */
    counter++;
    /*         if debugging then begin
                write(output,' ',counter:2,' ');
                for i := 1 to na do write(output,'a');
                for i := 1 to nc do write(output,'c');
                for i := 1 to ng do write(output,'g');
                for i := 1 to nt do write(output,'t');
                write(output,' ',na:3,nc:3,ng:3,nt:3);
                writeln(output,' pnb = ',pnb:10:5);
             end;  */
    total += pnb;

    /* the remaining portion of this repeat loop generates
    the values of na, nc, ng and nt.  notice that
    there are 7 possibilities at each loop increment.
    other than the stop, in each case the sum of
    na+nc+ng+nt remains constant (=n). */
    if (nt > 0) {  /* ending on a t - do outer loops */
      if (ng > 0) {  /* turn g into t */
	ng--;
	nt++;
      } else if (nc > 0) {
	/* turn one c into g,
	   and all t to g (note ng = 0 initially) */
	nc--;
	ng = nt + 1;
	nt = 0;
      } else if (na > 0) {
	/* turn one a into c and
	   all g and t to c. (note ng=nc=0 initially) */
	na--;
	nc = nt + 1;
	nt = 0;
      } else
	done = true;   /* since nt = n */
    } else {
      if (ng > 0) {  /* turn g into t */
	ng--;
	nt++;
      } else if (nc > 0) {  /* turn c into g */
	nc--;
	ng++;
      } else {
	na--;
	nc++;
	/* na > 0; turn a into c */
      }
    }
  } while (!done);

  /* no t - increment innermost loop */
  /* final adjustment: we only have the sum of squares so far */
  *varhnb = sshnb - *ehnb * *ehnb;

  /* if this message appears, there is either a bug in the code or
     the computer cannot be as accurate as requested */
  if (accuracy != (long)floor(accuracy * total + 0.5)) {
    printf(" procedure calehnb: the sum of probabilities is\n");
    printf(" not accurate to one part in %ld\n", (long)accuracy);
    printf(" the sum of the probabilities is %10.8f\n", total);
  }

  /* if this message appear, then there is an error in the
     repeat-until loop: it did not repeat as many times as
     is expected from the algorithm */
  if (counter == (long)floor((n + 1.0) * (n + 2) * (n + 3) / 6 + 0.5))
    return;
  /*      writeln(output, '    total: ',total:10:5);
        writeln(output,'    count = ',counter:1);
        writeln(output,'    (n+1)*(n+2)*(n+3)/6 = ',
                            round((n+1)*(n+2)*(n+3)/6):1); */
  ERROR(" procedure calehnb: program error, the number of calculations is in error\n", exception);
}  /* calehnb */

/** Information of a single column using the approximate correction term
 * a la Basharin.
 *
 * The value is initial uncertainty - observed uncertainty - correction term.
 * The error is sqrt(|sum f lg2(f)^2 - h^2| / N), with N the number of
 * non-gap characters.
 *
 * @param col alignment column
 * @return information and its error; (0, 0) when the column has no non-gap characters
 */
Physical
CorrectedMutualInformationScorer::computeApproximateInformation(const AlignmentColumn& col) const {
  const unsigned int nonGapRows = col.getNonGapCount();
  if (nonGapRows == 0) {
    return Physical(0.0, 0.0);
  }
  const double h = uncorrectedUncertainty(col);
  const double maxUncertainty = initialUncertainty(alphabet.size());
  const double bias = approximateCorrection(alphabet.size(), nonGapRows);
  const double information = maxUncertainty - h - bias;
  const double secondMoment = uncorrectedUncertaintyVariance(col);
  const double variance = (secondMoment - (h*h)) / nonGapRows;
  return Physical(information, sqrt(fabs(variance)));
}

/** Information of a single alignment column (P log P, like Tom Schneider).
 *
 * Depending on the state of the scorer and the column one of four paths is taken:
 *  - no non-gap characters: (0, 0) is returned;
 *  - correctionMode off: initial uncertainty - observed uncertainty, error 0;
 *  - at least 50 non-gap characters: approximate correction
 *    (computeApproximateInformation);
 *  - fewer characters and an alphabet of size 4: exact correction via calehnb
 *    with equal genome counts;
 *  - otherwise: the expected uncertainty is estimated from randomly generated
 *    columns of the same length.
 *
 * @param col alignment column
 * @return information value together with its error estimate
 */
Physical
CorrectedMutualInformationScorer::computeInformation(const AlignmentColumn& col) const {
  const unsigned int maxRandomDraws = 1000;       // upper bound on sampled random columns
  const double meanErrorTarget = 0.001;           // sampling stops once the error of the mean is below this
  const unsigned int approximationThreshold = 50; // from this many characters on the approximation is used
  unsigned int nonGapRows = col.getNonGapCount(); // number of non-gap characters
#ifdef DEBUG_VERBOSE
  cout << "Starting computeInformation(AlignmentCol): " << nonGapRows << endl;
#endif
  if (nonGapRows == 0) {
#ifdef DEBUG_VERBOSE
    cout << "Warning: no characters defined!" << endl;
#endif
    return Physical(0.0, 0.0);
  }
  if (!correctionMode) {
    // no error correction in this mode
    return Physical(initialUncertainty(alphabet.size()) - uncorrectedUncertainty(col), 0.0);
  }
  Physical information;
  if (nonGapRows >= approximationThreshold) {
#ifdef DEBUG_VERBOSE
    cout << "Compute approximate information!" << endl;
#endif
    information = computeApproximateInformation(col); // use approximation formula!
  }
  else if (alphabet.size() == 4) { // only call calehnb for nucleotide alphabets of size 4
    double observed = uncorrectedUncertainty(col);
    // outputs of calehnb
    double hg = 0.0;
    double ehnb = 0.0;
    double varhnb = 0.0;
    const long equalCount = 1; // same genome count for all four characters
    calehnb(static_cast<long>(nonGapRows), equalCount, equalCount, equalCount, equalCount, &hg, &ehnb, &varhnb);
    // expected uncertainty - observed uncertainty = information
    double value = ehnb - observed;
#ifdef DEBUG_VERBOSE
    cout << "Uncorrected results and calehnb: " << value << " " 
	 << observed << " " << hg << " " << ehnb << " " 
	 << varhnb << endl;
#endif
    double error = sqrt(varhnb); // TODO not correct yet!????
    information = Physical(value, error);
  }
  else { // generate random sequences:
    double observed = uncorrectedUncertainty(col);
    double sum = 0.0;
    double sumOfSquares = 0.0;
    Random& rnd = Random::getInstance();
    string randomColumn(nonGapRows, 'X');
    unsigned int draws = 0;
    const unsigned int minDraws = 10;
    while (draws < maxRandomDraws) {
      generateRandomSequence(randomColumn, rnd);
      const double sampled = uncorrectedUncertainty(AlignmentColumn(randomColumn, alphabet));
      sum += sampled;
      sumOfSquares += sampled * sampled;
      ++draws;
      // stop early once the mean is known precisely enough
      if ((draws > minDraws) && (sqrt(varianceFromSum(sum, sumOfSquares, draws)/draws) < meanErrorTarget)) {
	break;
      }
    }
    double expected = sum / draws;
    double error = sqrt(varianceFromSum(sum, sumOfSquares, draws));
    double value = expected - observed;
    information = Physical(value, error);
  }
#ifdef DEBUG_VERBOSE
  cout << "Finished computeInformation(AlignmentCol)" << information << endl;
#endif
  return information;
}

/** Overwrites every character of result with a randomly drawn alphabet character.
 * Does not change the size of the string!
 *
 * @param result string to fill; its current length determines how many characters are drawn
 * @param rnd random number source; one getRand() call is made per character
 */
void
CorrectedMutualInformationScorer::generateRandomSequence(string& result,
					   Random& rnd) const
{
  for (string::iterator pos = result.begin(); pos != result.end(); ++pos) {
    *pos = alphabet.getChar(rnd.getRand() % alphabet.size());
  }
}

/** Pair information of two alignment columns together with an estimate of the
 * small sample noise as a standard deviation.
 *
 * The value is initial uncertainty (lg2 of alphabet size squared) - joint
 * uncertainty - correction. With correctionMode off no correction is applied
 * and the error is 0. Otherwise, for fewer than 100 rows the correction is
 * estimated from pairs of randomly generated columns; for more rows the
 * approximate correction term is used (see Basharin paper).
 * Assumes that rows which contain at least one gap have been removed already
 * from both alignment columns.
 *
 * @param col1 first cleaned alignment column
 * @param col2 second cleaned alignment column
 * @return information and its error; (0, 0) when col1 has no non-gap characters
 */
Physical
CorrectedMutualInformationScorer::computeInformationClean(const AlignmentColumn& col1, 
							  const AlignmentColumn& col2) const 
{
  unsigned int rowCount = col1.getNonGapCount();
  if (rowCount == 0) {
    return Physical(0.0, 0.0);
  }
  double maxUncertainty = initialUncertainty(alphabet.size()*alphabet.size());
  double h = uncorrectedUncertainty(col1, col2);
  if (!correctionMode) {
    return  Physical(maxUncertainty - h, 0.0); // uncorrected decrease in uncertainty
  }
  const unsigned int samplingLimit = 100;   // below this many rows the correction is sampled
  const unsigned int maxRandomDraws = 1000; // upper bound on sampled random column pairs
  const double meanErrorTarget = 0.001;     // sampling stops once the error of the mean is below this
  double bias = 0.0;
  double error = 0.0;
  if (rowCount >= samplingLimit) { // correction term
    bias = approximateCorrection(alphabet.size()*alphabet.size(), rowCount);
    error = (uncorrectedUncertaintyVariance(col1, col2)-(h*h)) / col1.getNonGapCount(); // see Basharin paper
    error = sqrt(fabs(error));
  }
  else {
    double sum = 0.0;
    double sumOfSquares = 0.0;
    Random& rnd = Random::getInstance();
    string randomFirst(rowCount, 'X');
    string randomSecond(rowCount, 'X');
    unsigned int draws = 0;
    const unsigned int minDraws = 10;
    while (draws < maxRandomDraws) {
      // first column is always drawn before the second one
      generateRandomSequence(randomFirst, rnd);
      generateRandomSequence(randomSecond, rnd);
      const double sampled = uncorrectedUncertainty(AlignmentColumn(randomFirst, alphabet), 
						    AlignmentColumn(randomSecond, alphabet));
      sum += sampled;
      sumOfSquares += sampled * sampled;
      ++draws;
      // stop early once the mean is known precisely enough
      if ((draws > minDraws) && (sqrt(varianceFromSum(sum, sumOfSquares, draws)/draws) < meanErrorTarget)) {
	break;
      }
    }
    double expected = sum / draws;
    bias = maxUncertainty - expected; // this is the bias towards lower uncertainties : has to be compensated later!
    error = sqrt(varianceFromSum(sum, sumOfSquares, draws));
  }
  double value = maxUncertainty - h - bias; // sum p log p
  return Physical(value, error);
}

/** Returns the two columns restricted to the rows in which neither of them has a gap.
 *
 * Only rows present in both strings are examined: if the strings differ in
 * length, the surplus rows of the longer one are dropped. (Indexing the
 * shorter string past its end would be undefined behaviour.)
 *
 * @param s1Orig first alignment column, one character per sequence
 * @param s2Orig second alignment column, one character per sequence
 * @return pair (cleaned first column, cleaned second column), both of equal length
 */
pair<string, string> 
CorrectedMutualInformationScorer::cleanAliColumns(const string& s1Orig, const string& s2Orig) const
{
  // number of rows that exist in both columns
  const string::size_type commonRows =
    (s2Orig.size() < s1Orig.size()) ? s2Orig.size() : s1Orig.size();
  Vec<unsigned int> okRows;
  for (unsigned int i = 0; i < commonRows; ++i) {
    if (! (alphabet.isGap(s1Orig[i]) || alphabet.isGap(s2Orig[i])) ) {
      okRows.push_back(i);
    }
  }
  string s1 = getSubset(s1Orig, okRows);
  string s2 = getSubset(s2Orig, okRows);
  return pair<string, string>(s1, s2);
}

/** Returns a copy of one alignment column with all gap characters removed.
 *
 * @param s1Orig alignment column, one character per sequence
 * @return the non-gap characters of s1Orig in their original order
 */
string
CorrectedMutualInformationScorer::cleanAliColumn(const string& s1Orig) const
{
  Vec<unsigned int> keptRows;
  for (unsigned int row = 0; row < s1Orig.size(); ++row) {
    if (alphabet.isGap(s1Orig[row])) {
      continue; // drop gap rows
    }
    keptRows.push_back(row);
  }
  return getSubset(s1Orig, keptRows);
}




