#include <SequenceAlignmentTools.h>
#include <StringTools.h>
#include <iostream>
#include <vectornumerics.h>
#include <SimpleSequenceAlignment.h>

double
SequenceAlignmentTools::computeGCContent(const SequenceAlignment& ali,
					 SequenceAlignment::size_type start,
					 SequenceAlignment::size_type stop) {
  unsigned int charCount = 0;
  unsigned int gcCount = 0;
  for (SequenceAlignment::size_type i = 0; i < ali.size(); ++i) {
    string s = ali.getSequence(i);
    upperCase(s); // convert to upper case
    for (string::size_type j = start; j < stop; ++j) {
      char c = s[j];
      if (!SequenceAlignment::isGap(c)) {
	++charCount;
	if ((c == 'G') || (c == 'C')) {
	  ++gcCount;
	}
      }
    }
  }
  if (charCount == 0) {
    return 0.0;
  }
  return static_cast<double>(gcCount)/ static_cast<double>(charCount);
}

double
SequenceAlignmentTools::computeGCContent(const SequenceAlignment& ali) {
  return computeGCContent(ali, 0, ali.getLength());
}

/** Computes the GC content of consecutive, non-overlapping windows of winLength columns.
 * Element w of the result is the GC content of columns
 * w*winLength <= column < (w+1)*winLength.
 * Only complete windows are evaluated: the result has ali.getLength() / winLength
 * elements, and trailing columns that do not fill a whole window are ignored.
 * Returns an empty vector if the alignment is shorter than one window.
 * winLength must be greater than zero. */
Vec<double>
SequenceAlignmentTools::computeWindowGCContent(const SequenceAlignment& ali, SequenceAlignment::size_type winLength) {
  ASSERT(winLength > 0);
  SequenceAlignment::size_type numWin = ali.getLength() / winLength;
  if (numWin <= 0) {
    return Vec<double>();
  }
  Vec<double> result = Vec<double>(numWin, 0.0);
  // Iterate over window indices rather than column offsets: every one of the
  // numWin complete windows is evaluated, including a final window that ends
  // exactly at the last column (the former test "(i + winLength) < length"
  // skipped that window and left its entry at 0.0).
  for (SequenceAlignment::size_type w = 0; w < numWin; ++w) {
    SequenceAlignment::size_type winStart = w * winLength;
    result[w] = computeGCContent(ali, winStart, winStart + winLength);
  }
  return result;
}

/** Tries to shuffle the alignment by swapping one pair of columns; one of the
 * two columns is the column at position pos, the other one is chosen at random.
 * A swap is accepted only if the dinucleotide frequencies of the resulting
 * alignment pass checkFrequenciesOk against origFrequencies with limit normLimit.
 * At most 5 random partner positions are tried; drawing pos itself uses up a try.
 * If shuffleColumnMode is true, the characters within each of the two columns are
 * additionally permuted (random_shuffle) before the swap.
 * A rejected swap is undone using the unshuffled column contents, so that the
 * alignment is left exactly as it was before the attempt.
 * Returns true if one pair of columns was swapped, false otherwise. */
bool
SequenceAlignmentTools::dinucleotideShuffleIterationPosition(SequenceAlignment& ali,
							     const string& alphabet,
							     size_type pos,
							     const Vec<Vec<double> >& origFrequencies,
							     double normLimit,
							     bool shuffleColumnMode) {
  size_type maxTry = 5;
  // Keep the untouched column: posCol may get shuffled below, and undoing a
  // rejected swap with the shuffled version would silently alter the alignment
  // without any frequency check.
  string origPosCol = ali.getColumn(pos);
  string posCol = origPosCol;
  Random& rnd = Random::getInstance();
  if (shuffleColumnMode) {
    random_shuffle(posCol.begin(), posCol.end());
  }
  for (size_type trial = 0; trial < maxTry; ++trial) {
    size_type newPos = rnd.getRand(ali.getLength());
    if (newPos == pos) {
      continue; // cannot swap a column with itself; counts as a used-up try
    }
    string origNewCol = ali.getColumn(newPos);
    string newCol = origNewCol;
    if (shuffleColumnMode) {
      random_shuffle(newCol.begin(), newCol.end());
    }
    // tentatively apply the swap
    ali.setColumn(posCol, newPos);
    ali.setColumn(newCol, pos);
    Vec<Vec<double> > frequencies = computeDinucleotideFrequencies(ali, alphabet);
    if (checkFrequenciesOk(frequencies, origFrequencies, normLimit)) {
      return true; // found feasible swap
    }
    // rejected: restore both columns to their original, unshuffled content
    ali.setColumn(origPosCol, pos);
    ali.setColumn(origNewCol, newPos);
  }
  return false;
}

/** One pass of dinucleotide-preserving shuffling.
 * Calls dinucleotideShuffleIterationPosition once per alignment column, using
 * the positions delivered by generateRandomIndexSubset (called with the
 * alignment length as both size arguments) in the order they are returned.
 * Returns the number of calls that reported a successful column swap. */
SequenceAlignmentTools::size_type
SequenceAlignmentTools::dinucleotideShuffleIteration(SequenceAlignment& ali,
						     const string& alphabet,
						     const Vec<Vec<double> >& origFrequencies,
						     double normLimit,
						     bool shuffleColumnMode) {
  Vec<unsigned int> visitOrder = generateRandomIndexSubset(ali.getLength(), ali.getLength(), 0);
  size_type numSwapped = 0;
  size_type k = 0;
  while (k < ali.getLength()) {
    const bool swapped = dinucleotideShuffleIterationPosition(ali, alphabet, visitOrder[k],
							      origFrequencies, normLimit,
							      shuffleColumnMode);
    if (swapped) {
      ++numSwapped;
    }
    ++k;
  }
  return numSwapped;
}

/** Dinucleotide-preserving shuffling of the alignment columns.
 * The dinucleotide frequencies of the unmodified alignment serve as reference;
 * dinucleotideShuffleIteration is then run "iterations" times on ali.
 * Afterwards the frequencies of the shuffled alignment are re-checked against
 * the reference with checkFrequenciesOk; a failure is reported through ERROR_IF.
 * Returns the total number of successful column swaps. */
SequenceAlignmentTools::size_type
SequenceAlignmentTools::dinucleotideShuffle(SequenceAlignment& ali,
					    const string& alphabet,
					    double normLimit,
					    size_type iterations,
					    bool shuffleColumnMode) {
  // snapshot of the input taken before any column is touched
  SimpleSequenceAlignment snapshot = ali;
  Vec<Vec<double> > referenceFreqs = computeDinucleotideFrequencies(snapshot, alphabet);
  size_type totalSwaps = 0;
  size_type remaining = iterations;
  while (remaining > 0) {
    totalSwaps += dinucleotideShuffleIteration(ali, alphabet, referenceFreqs,
					       normLimit, shuffleColumnMode);
    --remaining;
  }
  // final sanity check on the shuffled alignment
  Vec<Vec<double> > shuffledFreqs = computeDinucleotideFrequencies(ali, alphabet);
  const bool stillOk = checkFrequenciesOk(shuffledFreqs, referenceFreqs, normLimit);
  ERROR_IF(!stillOk,
	   "Internal error in dinucleotide-preserving shuffling!");
  return totalSwaps;
}

/** Returns a matrix of dinucleotide frequencies with one row and one column
 * per alphabet character: element [i][j] is based on
 * ali.countDiCharacter(alphabet[i], alphabet[j]) plus a pseudo-count of 1.0.
 * The matrix is passed through probabilityNormalize before being returned. */
Vec<Vec<double> >
SequenceAlignmentTools::computeDinucleotideFrequencies(const SequenceAlignment& ali,
						       const string& alphabet) {
  const double pseudoCount = 1.0; // Bayesian pseudo-count added to every cell
  const size_type numChars = alphabet.size();
  Vec<Vec<double> > freqs(numChars, Vec<double>(numChars, 0.0));
  for (size_type first = 0; first < numChars; ++first) {
    for (size_type second = 0; second < numChars; ++second) {
      size_type observed = ali.countDiCharacter(alphabet[first], alphabet[second]);
      freqs[first][second] = pseudoCount + static_cast<double>(observed);
    }
  }
  probabilityNormalize(freqs); // normalize such that sum of all frequencies is one
  return freqs;
}

/** Checks if all dinucleotide frequencies differences are below normLimit */
bool
SequenceAlignmentTools::checkFrequenciesOk(const Vec<Vec<double> >& f1,
					   const Vec<Vec<double> >& f2,
					   double normLimit) {
  ASSERT(f1.size() == f2.size());
  for (size_type i = 0; i < f1.size(); ++i) {
    for (size_type j = 0; j < f1[i].size(); ++j) {
      double d = fabs(f1[i][j] - f2[i][j]);
      if ( d > normLimit) {
	// cout << "Frequency check failed: " << f1 << endl << f2 << endl;
	return false;
      }
    }
  }
  return true;
}

/** Returns the concatenation of the sequences with indices min <= i < max,
 * in index order. If max == 0, max will be set to the total number of
 * sequences of the alignment. The result is empty if min >= max.
 * max is not checked against the number of sequences. */
SequenceAlignment::sequence_type
SequenceAlignmentTools::generateConcatenatedSequences(const SequenceAlignment& ali,
						      SequenceAlignment::size_type min,
						      SequenceAlignment::size_type max) {
  if (max == 0) {
    max = ali.size();
  }
  SequenceAlignment::sequence_type result;
  for (SequenceAlignment::size_type i = min; i < max; ++i) {
    // append in place; "result = result + ..." re-copied the whole
    // accumulated sequence on every iteration
    result += ali.getSequence(i);
  }
  return result;
}

