//
// Programmer:    Craig Stuart Sapp <craig@ccrma.stanford.edu>
// Creation Date: Sat Jul  1 02:48:55 PDT 2006
// Last Modified: Mon Jul  3 06:25:05 PDT 2006 (converted from hicor)
// Last Modified: Thu Oct  5 21:39:33 PDT 2006 (added color from input file)
// Last Modified: Wed Jul 11 23:12:42 PDT 2007 (added smoothing option)
// Filename:      ...sig/examples/all/polyhicor.cpp
// Web Address:   http://sig.sapp.org/examples/museinfo/polyhicor/polyhicor.cpp
// Syntax:        C++; museinfo
//
// Description:   Generate a comparison between multiple sets of 
//                number sequences.  
//

#include "Array.h"
#include "Options.h"
#include "PixelColor.h"

#include <string.h>
#include <stdio.h>
#include <math.h>

#ifndef OLDCPP
   using namespace std;
   #include <iostream>
   #include <fstream>
   #include <iomanip>
#else
   #include <iostream.h>
   #include <fstream.h>
   #include <iomanip.h>
#endif


// Pixel coloring styles understood by printCorrelationPixel().
// COLOR_HUE is the default; COLOR_BW is selected with the --bw option.
#define COLOR_BW  0
#define COLOR_HUE 1

// Plot shapes dispatched on in main().  SHAPE_TRIANGLE is the default;
// SHAPE_RECTANGLE is selected with the -r option.
#define SHAPE_TRIANGLE  0
#define SHAPE_RECTANGLE 1


// function declarations:
void     checkOptions(Options& opts, int argc, char** argv);
void     example(void);
void     usage(const char* command);
void     calculateCorrelations(Array<Array<double> >& correlations,
                                  Array<double>& a, 
                                  Array<double>& b);
double   pearsonCorrelation(int size, double* a, double* b);
double   getMean(int size, double* a);
void     printPolyImageTriangle(Array<Array<Array<double> > >& correlations);
void     printPolyImageRectangle(Array<Array<Array<double> > >& correlations);
void     printCorrelationPixel(int valueindex, int count, 
                                  double correlation, int style);
int      getBestIndex(Array<Array<Array<double> > >& correlations, 
		                  int i, int j, int targetindex);

void     printArea(Array<const char*>& pids, 
                                 Array<Array<Array<double> > >& correlationset);
void     readData(Array<Array<double> >& data, 
                                  const char* hfile);
void     smoothData(Array<Array<double> >& data, double gain);
void     unsmoothData(Array<Array<double> >& data, double gain);
void     smoothSequence(Array<double>& sequence, double gain);
void     unsmoothSequence(Array<double>& sequence, double gain);


// User interface variables (filled in by checkOptions()):
Options   options;
int       dataq     = 0;              // debug: display input data only with -d
int       colortype = COLOR_HUE;      // coloring style; COLOR_BW with --bw
int       plotshape = SHAPE_TRIANGLE; // used with -r option
int       targetindex = 0;            // used with -n option (stored zero-based)
int       scaleq    = 0;              // used with -s option
int       areaQ     = 0;              // used with -A option
int       lowlimit  = 0;              // used with -A option (clamped to >= 0)
double    smooth    = 0.0;            // used with -S option
int       unsmoothq = 0;              // used with -U option


// Tables derived from the input file (filled in by readData()):
// colorindex -- colors parsed from "*c=" tokens in the input.
Array<PixelColor> colorindex;
// pids -- text of the most recent "pid"/"!pid" line; split in place.
char pids[50000] = {0};
// labels -- text of the first line treated as column labels; split in place.
char labels[50000] = {0};
// pidindex -- one name per data column, pointing into pids (or labels).
Array<const char*> pidindex;
// labelindex -- one label per data column, pointing into labels.
Array<const char*> labelindex;
// null -- placeholder string for columns which have no name.
const char null[2] = ".";


//////////////////////////////////////////////////////////////////////////

int main(int argc, char** argv) {
   // process the command-line options
   checkOptions(options, argc, argv);

   colorindex.setSize(1000);
   colorindex.setSize(0);

   Array<Array<double> > data;
   data.setSize(100);
   data.setSize(0);
   data.setGrowth(100);
   data.allowGrowth(1);

   const char* filename = "";

   if (options.getArgCount() <= 0) {
      filename = "<STDIN>";
   
   } else {
      filename = options.getArg(1);
   }

   readData(data, filename);

   if (smooth != 0.0) {
      if (!unsmoothq) {
         smoothData(data, smooth);
      } else {
         unsmoothData(data, smooth);
      }
   }


/*
   int inputsize = 0;
   double inputvalues[200];

   int ii;
   char *ptr;
   double value;
   const char* blanks = " \t,:";
   char templine[4096];
   while (!infile.eof()) {
      infile.getline(templine, 4096, '\n');
      if (infile.eof() && (strcmp(templine, "") == 0)) {
         break;
      } else if (isdigit(templine[0]) || templine[0] == '-' ||
                  templine[0] == '+') {
         ptr = strtok(templine, blanks);
         inputsize = 0;
         while (ptr != NULL && sscanf(ptr, "%lf", &value) == 1) {
            inputvalues[inputsize++] = value;
            if (inputsize >= 100) {
               break;
            }
            ptr = strtok(NULL, blanks);
         }
         if (inputsize > 0 && data.getSize() == 0) {
            data.setSize(inputsize);
            for (ii=0; ii<data.getSize(); ii++) {
               data[ii].setSize(1000);
               data[ii].setGrowth(1000);
               data[ii].allowGrowth(1);
               data[ii].setSize(0);
            }
         }
         for (ii=0; ii<data.getSize() && ii < inputsize; ii++) {
            data[ii].append(inputvalues[ii]);
         }
      } else if (strncmp("*c=", templine, 3) == 0) {
         ptr = strtok(templine, blanks);
         while (ptr != NULL) {
            if (strncmp("*c=", ptr, 3) == 0) {
               newcolor = PixelColor::getColor(ptr+3);
               colorindex.append(newcolor);
	       if (!areaQ) {
                  cerr << "Storing color: " << ptr + 3 << endl;
               }
            } else {
               cerr << "Error in string: " << ptr << endl;
            }
	    ptr = strtok(NULL, blanks);
         }
	 colorindex.allowGrowth(0);
      } else if (strncmp("!pid", templine, 4) == 0) {
         strcpy(pids, templine);
      }
   }

*/

   int i, j;
   if (dataq) {
      for (i=0; i<data[0].getSize(); i++) {
         for (j=0; j<data.getSize(); j++) {
            cout << data[j][i];
            if (j < data.getSize() - 1) {
               cout << "\t";
            }
         }
         cout << "\n";
      }
      exit(0);
   } 

   Array<Array<Array<double> > > correlationset;
   correlationset.setSize(data.getSize());


   // check the bounds of target index:
   if (targetindex < 0) {
      targetindex = 0;
   } else if (targetindex >= data.getSize()) {
      targetindex = 0;
   }

   for (i=0; i<correlationset.getSize(); i++) {
      correlationset[i].setSize(0);
      if (i == targetindex) {
         correlationset[i].allowGrowth(0);
         continue;
      }

      calculateCorrelations(correlationset[i], data[targetindex], data[i]);
      correlationset[i].allowGrowth(0);
   }

   if (areaQ) {
      printArea(pidindex, correlationset);
      return 0;
   }

   switch (plotshape) {
      case SHAPE_RECTANGLE:
         printPolyImageRectangle(correlationset);
         break;
      case SHAPE_TRIANGLE:
      default:
         printPolyImageTriangle(correlationset);
   }


   return 0;
}


//////////////////////////////////////////////////////////////////////////



//////////////////////////////
//
// smoothData -- 
//

void smoothData(Array<Array<double> >& data, double gain) {
   // Apply smoothSequence() in place to every sequence in data,
   // using the same gain for each of them.
   int i;
   for (i=0; i<data.getSize(); i++) {
      smoothSequence(data[i], gain);
   }
}



//////////////////////////////
//
// unsmoothData -- 
//

void unsmoothData(Array<Array<double> >& data, double gain) {
   // Apply unsmoothSequence() in place to every sequence in data,
   // using the same gain for each of them.
   int i;
   for (i=0; i<data.getSize(); i++) {
      unsmoothSequence(data[i], gain);
   }
}



//////////////////////////////
//
// smoothSequence -- smooth the sequence in place with a
//    symmetric exponential smoothing filter (applied first in the
//    reverse and then in the forward direction) with the specified gain.
//
//    Difference equation for smoothing: y[n] = k * x[n] + (1-k) * y[n-1]
//

void smoothSequence(Array<double>& sequence, double gain) {
   // Each pass is the one-pole filter y[n] = k * x[n] + (1-k) * y[n-1]
   // run over the sequence in place, so every step reads the neighbor
   // which that same pass has already filtered.  The first element
   // visited by each pass is left unchanged.
   double oneminusgain = 1.0 - gain;
   int i;
   int ssize = sequence.getSize();

   // reverse filtering first 
   for (i=ssize-2; i>=0; i--) {
      sequence[i] = gain*sequence[i] + oneminusgain*sequence[i+1];
   }

   // then forward filtering
   for (i=1; i<ssize; i++) {
      sequence[i] = gain*sequence[i] + oneminusgain*sequence[i-1];
   }

}



//////////////////////////////
//
// unsmoothSequence -- remove the smoothed sequence from the original sequence.
//

void unsmoothSequence(Array<double>& sequence, double gain) {
   // Smooth a copy of the sequence and then subtract it element by
   // element from the original, so that sequence ends up holding the
   // difference between the input and its smoothed version.
   Array<double> smoothed = sequence;
   smoothSequence(smoothed, gain);
   int i;
   for (i=0; i<sequence.getSize(); i++) {
      sequence[i] -= smoothed[i];
   } 
}


//////////////////////////////
//
// printArea -- print the amount of area for each correlation
//


void printArea(Array<const char*>& pidindex,
      Array<Array<Array<double> > >& correlationset) {
   // For every cell of the correlation triangle (excluding the bottom
   // row and the lowlimit rows above it), find which sequence correlates
   // best with the target, and then print for each sequence the fraction
   // of cells which it won.  The target sequence is printed with the
   // word "target" in place of a fraction.
   //
   // pidindex       -- display name for each sequence
   // correlationset -- one correlation triangle per sequence; the entry
   //                   at the global targetindex is empty.

   int setcount = correlationset.getSize();
   if (setcount <= 0) {
      // nothing to measure and nothing to label
      return;
   }

   int mastercount = 0;
   Array<int> localcount;
   localcount.allowGrowth(0);
   localcount.setSize(setcount);
   localcount.setAll(0);

   // Pick any sequence other than the target in order to learn the
   // triangle dimensions.  When the target is the only sequence this
   // wraps around to the (empty) target entry, giving a length of 0.
   int nontarget = targetindex + 1;
   if (nontarget >= setcount) {
      nontarget = 0;
   }
   int length = correlationset[nontarget].getSize();

   int i;
   int j;
   int datawidth;
   int bestindex;
   for (i=0; i<length-1-lowlimit; i++) {
      datawidth = correlationset[nontarget][i].getSize();

      for (j=0; j<datawidth; j++) {
         bestindex = getBestIndex(correlationset, i, j, targetindex);
         mastercount++;
         localcount[bestindex]++;
      }
   }

   // print the results:
   double fraction;
   for (i=0; i<setcount; i++) {
      if (i==targetindex) {
         cout << pidindex[i] << ":\t" << "target" << endl;
      } else {
         if (strcmp(pidindex[i], "") == 0) {
            // unnamed sequence: fall back to its index
            cout << i;
         } else {
            cout << pidindex[i];
         }
         // avoid printing nan when no cells were counted at all
         fraction = 0.0;
         if (mastercount > 0) {
            fraction = (double)localcount[i] / mastercount;
         }
         cout  << ":\t" 
               << fixed 
               << fraction
               << endl;
      }
   }

}



//////////////////////////////
//
// printPolyImageTriangle --
//

void printPolyImageTriangle(Array > >& correlations) {
   const char* background = options.getString("background");

   int nontarget = targetindex + 1;
   if (nontarget >= correlations.getSize()) {
      nontarget = 0;
   }
   
   int seqcount = correlations.getSize();
   int length = correlations[0].getSize();
   if (length == 0) {
      length = correlations[1].getSize();
   }

   int maxcolumn = length * 2;
   int maxrow    = (length - 1) * 2;

   cout << "P3\n";
   cout << maxcolumn << " " << maxrow << "\n";
   cout << "255\n";

   int i, j;
   int datawidth;
   int bgwidth;
   int bestindex;
   for (i=0; i<length-1; i++) {

      datawidth = correlations[nontarget][i].getSize();
      bgwidth = (maxcolumn - datawidth * 2)/2;

      for (j=0; j<bgwidth; j++) {
         cout << " " << background << " ";
      }
      for (j=0; j<datawidth; j++) {
         bestindex = getBestIndex(correlations, i, j, targetindex);
         printCorrelationPixel(bestindex, seqcount, 
                               correlations[bestindex][i][j], colortype);
      }
      for (j=0; j<bgwidth; j++) {
         cout << " " << background << " ";
      }
      cout << "\n";

      // repeat the row to double it (to make image square)      
      for (j=0; j<bgwidth; j++) {
         cout << " " << background << " ";
      }
      for (j=0; j<datawidth; j++) {
         bestindex = getBestIndex(correlations, i, j, targetindex);
         printCorrelationPixel(bestindex, seqcount, 
                               correlations[bestindex][i][j], colortype);
      }
      for (j=0; j<bgwidth; j++) {
         cout << " " << background << " ";
      }
      cout << "\n";
   }
}



//////////////////////////////
//
// getBestIndex -- return the index of the non-target sequence which has
//   the highest correlation at cell (i, j).  Ties go to the lowest index.
//

int getBestIndex(Array<Array<Array<double> > >& correlations, int i,
      int j, int targetindex) {

   // Start with the first sequence which is not the target, and then
   // scan the sequences after it.  A later sequence only replaces the
   // current winner when its correlation is strictly larger.
   int winner = (targetindex == 0) ? 1 : 0;
   const int total = correlations.getSize();

   for (int candidate = winner + 1; candidate < total; candidate++) {
      if (candidate != targetindex &&
            correlations[candidate][i][j] > correlations[winner][i][j]) {
         winner = candidate;
      }
   }

   return winner;
}



//////////////////////////////
//
// printPolyImageRectangle --
//

void printPolyImageRectangle(Array > >& correlations) {

 return;

/*
   int maxcolumn = correlation[0].getSize() * 2;
   int maxrow    = (correlation[0].getSize()-1) * 2;

   cout << "P3\n";
   cout << maxcolumn << " " << maxrow << "\n";
   cout << "255\n";

   int i, j;
   int datawidth;
   int jstretch;

   for (i=0; i<correlation.getSize()-1; i++) {
      datawidth = correlation[i].getSize();

      for (j=0; j<maxcolumn/2; j++) {
         jstretch = int(double(j)/maxcolumn*2*datawidth);
         printCorrelationPixel(correlation[i][jstretch], colortype);
      }
      cout << "\n";

      // repeat for double pixel size
      for (j=0; j<maxcolumn/2; j++) {
         jstretch = int(double(j)/maxcolumn*2*datawidth);
         printCorrelationPixel(correlation[i][jstretch], colortype);
      }
      cout << "\n";

   }
*/
}


 
//////////////////////////////
//
// printCorrelationPixel --
//

void printCorrelationPixel(int valueindex, int count, double correlation, 
      int style) {
   double value = (double)valueindex / count;

   PixelColor pc;

   switch (style) {
      case COLOR_HUE:
         if (valueindex < colorindex.getSize()) {
            pc = colorindex[valueindex];
         } else {
            cerr << "Value = " << value << endl;
            pc.setHue(value);
         }
         break;
      case COLOR_BW:
      default:
         pc.setGrayNormalized(value);
   }


   // add correlation highlighting back in later...

   /*   if (value == 5.0/6.0) {
      pc.setRed(146);
      pc.setGreen(11);
      pc.setBlue(241);
   }

   if (scaleq == 0) {
         if (value == 2.0/6.0) {
         pc.setRed(31);
         pc.setGreen(162);
         pc.setBlue(31);
      }
   }

   if (scaleq) {
      correlation = (correlation + 1.0)/2.0;
      pc.setRed(int(pc.getRed() *  correlation + 0.5));
      pc.setGreen(int(pc.getGreen() *  correlation + 0.5));
      pc.setBlue(int(pc.getBlue() *  correlation + 0.5));
   }
   */


   cout << " ";
   pc.writePpm3(cout);
   cout << " ";
	   
   // repeat the cell for a square image
   cout << " ";
   pc.writePpm3(cout);
   cout << " ";

}



//////////////////////////////
//
// calculateCorrelations --
//

void calculateCorrelations(Array<Array<double> >& correlations,
      Array<double>& a, Array<double>& b) {
   // Fill correlations with a triangle of Pearson correlations between
   // a and b.  Row i has i+1 cells; cell j of row i correlates the
   // window of (a.getSize() - i) elements which starts at offset j in
   // both sequences.  The last row (windows of one element) is set to 1.0.
   //
   // The triangle is always sized from a.  If b is shorter than a, a
   // window is cut down to the elements which b actually has, and a cell
   // whose window starts past the end of b is set to 0.0.

   int i, j;
   int asize = a.getSize();
   int bsize = b.getSize();

   correlations.setSize(asize);
   for (i=0; i<asize; i++) {
      correlations[i].setSize(i+1);
      correlations[i].allowGrowth(0);
   }

   if (asize <= 0) {
      // empty reference sequence: there is no bottom row to fill in
      return;
   }

   int rowsize;
   int corelsize;
   int window;

   for (i=0; i<asize-1; i++) {
      rowsize = correlations[i].getSize();
      corelsize = asize - i;
      for (j=0; j<rowsize; j++) {
         // j + corelsize never exceeds asize, but it can exceed bsize
         window = corelsize;
         if (j + window > bsize) {
            window = bsize - j;
         }
         if (window <= 0) {
            correlations[i][j] = 0.0;
         } else {
            correlations[i][j] = pearsonCorrelation(window, 
                  a.getBase() + j, b.getBase() + j);
         }
      }
   }

   // set the bottom row to 1.0 correlations
   i = asize-1;
   for (j=0; j<correlations[i].getSize(); j++) {
      correlations[i][j] = 1.0;
   }

}



//////////////////////////////
//
// pearsonCorrelation --
//

double pearsonCorrelation(int size, double* a, double* b) {
   // Pearson correlation of the first size elements of a and b.
   // Returns 0.0 when either input has no variation over that range.
   const double abar = getMean(size, a);
   const double bbar = getMean(size, b);

   double cross   = 0.0;   // sum of products of the deviations
   double aspread = 0.0;   // sum of squared deviations of a
   double bspread = 0.0;   // sum of squared deviations of b

   for (int k = 0; k < size; k++) {
      const double da = a[k] - abar;
      const double db = b[k] - bbar;
      cross   += da * db;
      aspread += da * da;
      bspread += db * db;
   }

   if (aspread != 0.0 && bspread != 0.0) {
      return cross / sqrt(aspread * bspread);
   }
   return 0.0;
}



//////////////////////////////
//
// getMean --
//

double getMean(int size, double* a) {
   // Arithmetic mean of the first size elements of a.
   // Returns 0.0 when size is zero or negative.
   double total = 0.0;

   if (size > 0) {
      const double* cursor = a;
      const double* const stop = a + size;
      while (cursor != stop) {
         total += *cursor;
         ++cursor;
      }
      total /= size;
   }

   return total;
}



//////////////////////////////
//
// checkOptions -- 
//

void checkOptions(Options& opts, int argc, char* argv[]) {
   // Register the command-line options, parse argv, deal with the
   // informational options (each of which exits the program), and
   // then copy the remaining settings into the global variables.

   // program-specific options:
   opts.define("d|data=b",    "display input data and then quit");
   opts.define("n=i:1",    "which column of data is the reference");
   opts.define("s|scale=b", "scale the coloring by the correlation value");
   opts.define("A|area=i:0", "Area measurement");
   opts.define("c|correlation=b", "display correlation data and then quit");
   opts.define("bw=b", "display correlations in black&white coloring");
   opts.define("r|rectangle=b", "display correlations in rectangular plot");
   opts.define("b|bg|background=s:255 255 255", "background color");
   opts.define("S|smooth=d:0.0", "Smoothing gain (0.0 = don't smooth)");
   opts.define("U|unsmooth=b", "apply unsmoothing");

   // informational options:
   opts.define("author=b",  "author of program"); 
   opts.define("version=b", "compilation info");
   opts.define("example=b", "example usages");   
   opts.define("help=b",  "short description");
   opts.process(argc, argv);
   
   // The informational options are checked in priority order;
   // only the first one found is acted upon.
   if (opts.getBoolean("author")) {
      cout << "Written by Craig Stuart Sapp, "
           << "craig@ccrma.stanford.edu, Jul 2006" << endl;
      exit(0);
   }
   if (opts.getBoolean("version")) {
      cout << argv[0] << ", version: 1 Jul 2006" << endl;
      cout << "compiled: " << __DATE__ << endl;
      exit(0);
   }
   if (opts.getBoolean("help")) {
      usage(opts.getCommand());
      exit(0);
   }
   if (opts.getBoolean("example")) {
      example();
      exit(0);
   }

   // processing settings:
   smooth      = opts.getDouble("smooth");
   unsmoothq   = opts.getBoolean("unsmooth");
   targetindex = opts.getInteger("n") - 1;   // stored zero-based
   scaleq      = opts.getBoolean("scale");
   dataq       = opts.getBoolean("data");

   // -A both switches on the area report and supplies its row limit
   areaQ       = opts.getBoolean("area");
   const int arealimit = opts.getInteger("area");
   lowlimit    = (arealimit < 0) ? 0 : arealimit;

   // display settings:
   if (opts.getBoolean("bw")) {
      colortype = COLOR_BW;
   }
   if (opts.getBoolean("rectangle"))  {
      plotshape = SHAPE_RECTANGLE;
   }
}



//////////////////////////////
//
// example --
//

void example(void) {
   // Print sample command lines for the --example option.  The
   // function has no access to the actual command name, so the
   // program name is written out literally.
   std::cout <<
      "# plot how well each column matches the first column:\n"
      "   polyhicor data.txt > plot.ppm\n"
      "\n"
      "# use the third column as the reference instead:\n"
      "   polyhicor -n 3 data.txt > plot.ppm\n"
      "\n"
      "# smooth every column with a gain of 0.2 before comparing:\n"
      "   polyhicor -S 0.2 data.txt > plot.ppm\n"
      "\n"
      "# print the fraction of the plot won by each column:\n"
      "   polyhicor -A 0 data.txt\n"
      "\n"
      "# check how the input file was parsed:\n"
      "   polyhicor -d data.txt\n";
}



//////////////////////////////
//
// usage --
//

void usage(const char* command) {
   // Print a short description of the program and its options for the
   // --help option.  The option list mirrors the definitions made in
   // checkOptions().
   //
   // command -- name the program was invoked with; "polyhicor" is
   //            used if it is NULL.
   const char* name = (command != 0) ? command : "polyhicor";

   std::cout <<
      "Usage: " << name << " [options] [datafile] > image.ppm\n"
      "\n"
      "Compare one column of numbers (the reference) against all of the\n"
      "other columns in the input, and print a plain-text PPM image which\n"
      "shows the best-correlating column at each position and window size.\n"
      "\n"
      "Options:\n"
      "   -n column         which column of data is the reference (default 1)\n"
      "   -A, --area rows   print the fraction of the plot won by each column\n"
      "                     instead of an image, ignoring the last rows\n"
      "   -S, --smooth gain smoothing gain (0.0 = don't smooth)\n"
      "   -U, --unsmooth    apply unsmoothing (used together with -S)\n"
      "   --bw              display correlations in black&white coloring\n"
      "   -b, --bg, --background \"R G B\"\n"
      "                     background color (default \"255 255 255\")\n"
      "   -r, --rectangle   display correlations in rectangular plot\n"
      "   -s, --scale       scale the coloring by the correlation value\n"
      "   -c, --correlation display correlation data and then quit\n"
      "   -d, --data        display input data and then quit\n"
      "   --author          author of program\n"
      "   --version         compilation info\n"
      "   --example         example usages\n"
      "   --help            short description\n";
}



//////////////////////////////
//
// readData -- read the input data array.
//

void readData(Array >& data, const char* filename) {
   data.setSize(500);
   data.setSize(0);
   data.setGrowth(500);

   PixelColor newcolor;
   data.allowGrowth(1);

   ifstream infile;
   #ifndef OLDCPP
      infile.open(filename, ios::in);
   #else
      infile.open(filename, ios::in | ios::nocreate);
   #endif

   if (!infile.is_open()) {
      cerr << "Cannot open file for reading: " << filename << endl;
      exit(1);
   }

   int inputsize = 0;
   double inputvalues[500];

   int ii;
   char *ptr;
   double value;
   const char* blanks = " \t,:";
   int foundlabels = 0;
   char templine[100000];
   while (!infile.eof()) {
      infile.getline(templine, 90000, '\n');
      if (!foundlabels) {
	 if (!((strncmp(templine, "pid",  3) == 0) ||
	       (strncmp(templine, "!!",   2) == 0) ||
	       (strncmp(templine, "*",    1) == 0) ||
	       (strncmp(templine, "!pid", 4) == 0))) {
            foundlabels = 1;
            strcpy(labels, templine);
            continue;
	 }
      }
      if (infile.eof() && (strcmp(templine, "") == 0)) {
         break;
      } else if (isdigit(templine[0]) || templine[0] == '-' ||
                  templine[0] == '+') {
         ptr = strtok(templine, blanks);
         inputsize = 0;
         while (ptr != NULL && sscanf(ptr, "%lf", &value) == 1) {
            inputvalues[inputsize++] = value;
            if (inputsize >= 100) {
               break;
            }
            ptr = strtok(NULL, blanks);
         }
         if (inputsize > 0 && data.getSize() == 0) {
            data.setSize(inputsize);
            for (ii=0; ii<data.getSize(); ii++) {
               data[ii].setSize(5000);
               data[ii].setGrowth(5000);
               data[ii].allowGrowth(1);
               data[ii].setSize(0);
            }
         }
         for (ii=0; ii<data.getSize() && ii < inputsize; ii++) {
            data[ii].append(inputvalues[ii]);
         }
      } else if (strncmp("*c=", templine, 3) == 0) {
         ptr = strtok(templine, blanks);
         while (ptr != NULL) {
            if (strncmp("*c=", ptr, 3) == 0) {
               newcolor = PixelColor::getColor(ptr+3);
               colorindex.append(newcolor);
	       //if (!areaQ) {
               //   cerr << "Storing color: " << ptr + 3 << endl;
               //}
            } else {
               cerr << "Error in string: " << ptr << endl;
            }
	    ptr = strtok(NULL, blanks);
         }
	 colorindex.allowGrowth(0);
      } else if (strncmp("!pid", templine, 4) == 0) {
         strcpy(pids, templine);
      } else if (strncmp("pid", templine, 3) == 0) {
         strcpy(pids, templine);
      }
   }


   pidindex.setSize(data.getSize());
   pidindex.allowGrowth(0);
   int i;
   int length = strlen(pids);

   for (i=0; i<pidindex.getSize(); i++) {
      pidindex[i] = null;
   }

   int counter = 0;
   for (i=0; i<length; i++) {
      if (strncmp(&(pids[i]), "pid", 3) == 0) {
         pidindex[counter++] = &(pids[i]);
         i += 3;
         continue;
      }
      if (pids[i] == '\t') {
         pids[i] = '\0';
      }
      if (pids[i] == ' ') {  // probably not necessary
         pids[i] = '\0';
      }
   }


   labelindex.setSize(data.getSize());
   labelindex.allowGrowth(0);
   length = strlen(labels);

   for (i=0; i<labelindex.getSize(); i++) {
      labelindex[i] = null;
   }

   counter = 0;
   labelindex[counter++] = &(labels[0]);
   for (i=0; i<length; i++) {
      if (labels[i] == '\t') {
         labels[i] = '\0';
         if (counter < labelindex.getSize()) {
            labelindex[counter++] = &(labels[i+1]);
         }
         i += 1;
         continue;
      }
   }

   for (i=0; i<pidindex.getSize(); i++) {
      if (strcmp(pidindex[i], ".") == 0) {
         if (i < labelindex.getSize()) {
            pidindex[i] = labelindex[i];
         }
      }
      if (pidindex[i][0] == '!') {
         pidindex[i] = &(pidindex[i][1]);
      }
   }

//   for (i=0; i<pidindex.getSize(); i++) {
//      cout << i << " :: " << pidindex[i] << endl;
//   }

	      

}


// md5sum: 18aed7a44e892db561571e1a6b458101 polyhicor.cpp [20080518]