//
// Programmer:    Craig Stuart Sapp <craig@ccrma.stanford.edu>
// Creation Date: Wed Jul 31 17:36:57 PDT 2002
// Last Modified: Wed Jul 31 17:37:00 PDT 2002
// Filename:      ...sig/examples/all/iwsimplex.cpp
// Web Address:   http://sig.sapp.org/examples/museinfo/humdrum/iwsimplex.cpp
// Syntax:        C++; museinfo
//
// Description:   Interval weights optimization via the downhill simplex
//                method of Nelder and Mead.
//

#include "museinfo.h"

#include <string.h>
#include <stdlib.h>
#include <math.h>
#include <time.h>


// function declarations
// (input parsing, option handling, output, and the root-error cost function;
// each definition below must use exactly these parameter types)
void   checkOptions(Options& opts, int argc, char* argv[]);
void   example(void);
void   usage(const char* command);
void   getStartingWeights(Array<double>& weights,
                                 HumdrumFile& weightfile);
void   getChordInformation(Array<ArrayInt>& pitches, Array<int>& root,
                                 Array<double>& count, HumdrumFile& datafile);
double getErrors(Array<double>& weights,
                                 Array<ArrayInt>& pitches, Array<int>& root,
                                 Array<double>& count);
// NOTE: printWeights takes its argument by value; its definition must too,
// otherwise the two would be distinct (ambiguous) overloads.
void   printWeights(Array<double> weights);
void   printKern(Array<double>& initialweights);


// Downhill simplex method functions
// The simplex is a Matrix with one column per weight (N columns) and one
// row per vertex (N+1 rows), as built by generateSimplex().
void amoeba(Matrix<double>& simplex, double ftol, 
      double(*testFunction)(Array<double>&), int& evalCount,
      int maxtests = 5000);
double amoebaTry(Matrix<double>& simplex, Array<double>& functionEvaluation,
      Array<double>& simplexSum, double(*testFunction)(Array<double>&),
      int worstPoint, double fac);
double testFunction(Array<double>& weights);
double runDownhillSimplexMethod(Array<double>& returnedWeights, 
      Array<double>& initialWeights);
void generateSimplex(Matrix<double>& simplex, Array<double>& initialWeights, 
      double delta);



// global variables
Options      options;                   // database for command-line arguments
int          debugQ      = 0;           // used with --debug option
int          verboseQ    = 0;           // used with -v option
double       tolerance   = 0.00000001;  // used with -t option
double       sidedelta   = 10.0;        // current simplex side (starts at -s, decays)
double       initdelta   = 0.01;        // value given with -s option
int          trials      = 5000;        // used with -r option
double       decay       = 0.8;         // used with -d option
int          normQ       = 0;           // used with -n option
int          norm        = 25;          // used with --nn option
int          repeatcount = 5;           // used with -c option


// global variables for testFunction
// (testFunction() must match the function-pointer signature amoeba() takes,
// so the chord data it scores against is shared through these globals)
Array<ArrayInt> pitches;         // base-40 pitches of each chord's notes
Array<int>      root;            // base-40 pitch class (0-39) of each root
Array<double>   counte;           // frequency of the chord occurences

///////////////////////////////////////////////////////////////////////////

int main(int argc, char* argv[]) {
   HumdrumFile datafile;
   HumdrumFile weightfile;

   checkOptions(options, argc, argv); // process the command-line options

   datafile.read(options.getArg(1));
   weightfile.read(options.getArg(2));

   Array<double> initialweights; // starting weights of the search
   Array<double> bestweights;
   Array<double> testweights;
   Array<double> currentweights;

   getStartingWeights(initialweights, weightfile);
   getChordInformation(pitches, root, counte, datafile);

   double starterr = getErrors(initialweights, pitches, root, counte);
   double besterr = starterr;
   double lasterr = starterr + 1;
   int itercount = 0;
   int testcount = 0;

   currentweights = initialweights;
   bestweights = currentweights;
   while (lasterr > besterr || testcount < repeatcount) {
      itercount++;
      if (verboseQ) {
         cout << "Test number " << itercount << endl;
      }
      lasterr = besterr;
      besterr = runDownhillSimplexMethod(testweights, currentweights);
      sidedelta *= decay;
      if (besterr < lasterr) {
         bestweights = testweights;
         testcount = 0;
         currentweights = bestweights;
      } else {
         testcount++;
      }
   }

   cout << "!! Starting Error Case:\t" << starterr << "\n";
   cout << "!! Best Error Case:\t"     << besterr << "\n";
   cout << "!! Tolerance:\t"           << tolerance << "\n";
   cout << "!! Initial Delta:\t"       << initdelta << "\n";
   cout << "!! Current Delta:\t"       << sidedelta << "\n";
   cout << "!! Max Trials:\t"          << trials << "\n";
   cout << "!! Repeat Count:\t"        << repeatcount << "\n";
   printKern(bestweights);

   return 0;
}



///////////////////////////////////////////////////////////////////////////




//////////////////////////////
//
// printKern --
//

void printKern(Array& weights) {
   char buffer[1024] = {0};

   int i;
   Array<double> w;
   w = weights;
   if (normQ) {
      double shift = weights[2];
      double scale = weights[norm] - shift;
      for (i=0; i<w.getSize(); i++) {
         w[i] = (w[i] - shift)/scale;
      }
   }

   cout << "**kern\t**weight\n";

   for (i=0; i<w.getSize(); i++) {
      if (i==5||i==11||i==22||i==28||i==34) {
         continue;
      }
      cout << Convert::base40ToKern(buffer, i+3*40) << "\t"
           << w[i] << "\n";
   }

   cout << "*-\t*-\n";

}



//////////////////////////////
//
// getErrors -- try all chords and return the weighted count of chords whose
//     root is misidentified by the given interval weights.
//
// weights = 40 weights indexed by (pitch - offset) mod 40.
// pitches = base-40 pitches of the notes of each chord.
// root    = base-40 pitch class (0-39) of each chord's labeled root.
// counte  = occurrence count of each chord; a misidentified chord adds its
//           count to the returned error.
//
// For each chord, every offset i in 0..39 is scored by summing
// weights[(p - i) mod 40] over the chord's pitches p.  The lowest-scoring
// offset (first one on ties) gives the measured root pitch class
// (i + 2) mod 40, so weight index 2 corresponds to the root itself.
//

double getErrors(Array<double>& weights, Array<ArrayInt>& pitches,
      Array<int>& root, Array<double>& counte) {

   Array<double> rootscores;   // summed weight for each candidate offset
   rootscores.setSize(40);
   rootscores.allowGrowth(0);
   int i, j, m;
   double errors = 0;
   int min;                    // candidate offset with the lowest score
   int measuredroot;           // root pitch class implied by min
   for (m=0; m<root.getSize(); m++) {
      rootscores.setAll(0.0);
      for (i=0; i<rootscores.getSize(); i++) {
         for (j=0; j<pitches[m].getSize(); j++) {
            if (pitches[m][j] < 0) {
               // negative values are not pitches; a very negative one would
               // also make the index below negative
               continue;
            }
            rootscores[i] += weights[(pitches[m][j]-i+400)%40];
         }
      }
      min = 0;
      for (i=0; i<rootscores.getSize(); i++) {
         if (rootscores[min] > rootscores[i]) {
            min = i;
         }
      }
      // Wrap into 0..39: offsets 38 and 39 correspond to root pitch
      // classes 0 and 1, which "min+2" (40, 41) could never match.
      measuredroot = (min + 2) % 40;
      if (root[m] != measuredroot) {
         if (debugQ) {
            cout << "Error: root=" << root[m]
                 << "\tbut measured: " << measuredroot << endl;
         }
         errors += counte[m];
      }
   }

   return errors;
}



//////////////////////////////
//
// getStartingWeights --
//

void getStartingWeights(Array& weights, HumdrumFile& weightfile) {
   weights.setSize(40);
   weights.setAll(100000);
   weights.allowGrowth(0);

   int i, j;
   int root;
   double weight;
   for (i=0; i<weightfile.getNumLines(); i++) {
      root = -1;
      weight = 10000;
      if (!weightfile[i].isData()) {
         continue;
      }
      for (j=0; j<weightfile[i].getFieldCount(); j++) {
         if (strcmp("**kern", weightfile[i].getExInterp(j)) == 0) {
            root = Convert::kernToBase40(weightfile[i][j]) % 40;
         } else if (strcmp("**weight", weightfile[i].getExInterp(j)) == 0) {
            weight = strtod(weightfile[i][j], NULL);
         }
      }
      if (root >= 0) {
         weights[root] = weight;
      }
   }

   if (debugQ) {
      printWeights(weights);
   }
}



//////////////////////////////
//
// printWeights --
//

void printWeights(Array weights) {
   char buffer[128] = {0};
   cout << "**kern\t**weight\n";
   int i;
   for (i=0; i<weights.getSize(); i++) {
      if (i==5 || i==11 || i==22 || i==28 || i==34) {
         continue;   // invalid interval
      }
      cout << Convert::base40ToKern(buffer, i+4*40);
      cout << "\t" << weights[i] << "\n";
   }
   cout << "*-\t*-\n";
}



//////////////////////////////
//
// getChordInformation -- extract one chord per data line of datafile.
//     root    = base-40 pitch class of the **root token.
//     counte  = numeric value of the **count token.
//     pitches = base-40 pitches of every note in the **kern token(s).
//     A line is kept only if it has a root, a positive count and at least
//     one pitch.  Tokens that do not convert to a pitch (negative base-40
//     value) are ignored.
//

void getChordInformation(Array<ArrayInt>& pitches, Array<int>& root,
     Array<double>& counte, HumdrumFile& datafile) {

   // size up then reset to empty (original idiom, presumably to reserve
   // storage for at most one chord per line)
   counte.setSize(datafile.getNumLines());
   counte.setSize(0);
   root.setSize(datafile.getNumLines());
   root.setSize(0);
   pitches.setSize(datafile.getNumLines());
   pitches.setSize(0);

   char buffer[1024] = {0};
   int i, j, k;
   int base40;
   int troot = -1;
   double tcount = 0.0;
   Array<int> tpitches;
   tpitches.setSize(100);
   tpitches.setSize(0);
   for (i=0; i<datafile.getNumLines(); i++) {
      // Only data lines carry chords; interpretation lines such as
      // "**root" must not be parsed as pitches.
      if (!datafile[i].isData()) {
         continue;
      }
      tpitches.setSize(0);
      troot = -1;
      tcount = 0.0;
      for (j=0; j<datafile[i].getFieldCount(); j++) {
         if (strcmp("**root", datafile[i].getExInterp(j)) == 0) {
            base40 = Convert::kernToBase40(datafile[i][j]);
            // a negative value taken "% 40" could look like a valid root
            if (base40 >= 0) {
               troot = base40 % 40;
            }
         } else if (strcmp("**count", datafile[i].getExInterp(j)) == 0) {
            tcount = strtod(datafile[i][j], NULL);
         } else if (strcmp("**kern", datafile[i].getExInterp(j)) == 0) {
            int notecount = datafile[i].getTokenCount(j);
            for (k=0; k<notecount; k++) {
               datafile[i].getToken(buffer, j, k);
               base40 = Convert::kernToBase40(buffer);
               // keep getErrors() from scoring non-pitch tokens
               if (base40 >= 0) {
                  tpitches.append(base40);
               }
            }
         }
      }

      if (troot < 0 || tcount <= 0.0 || tpitches.getSize() == 0) {
         continue;
      }
      root.append(troot);
      counte.append(tcount);
      pitches.append(tpitches);
   }


   if (debugQ) {
      cout << "**count\t**root\t**kern\n";
      for (i=0; i<counte.getSize(); i++) {
         cout << counte[i] << "\t";
         cout << Convert::base40ToKern(buffer, root[i]+3*40) << "\t";
         for (j=0; j<pitches[i].getSize(); j++) {
            cout << Convert::base40ToKern(buffer, pitches[i][j]);
            if (j < pitches[i].getSize() - 1) {
               cout << " ";
            }
         }
         cout << "\n";
      }
      cout << "*-\t*-\t*-\n";
   }

}



//////////////////////////////
//
// checkOptions -- validate and process command-line options.
//     Exits after --author, --version, --help or --example.  Exits with
//     status 1 if fewer than two file arguments are given (main() reads
//     arguments 1 and 2) or if -n is used with an --nn value outside the
//     40-entry weight table.
//

void checkOptions(Options& opts, int argc, char* argv[]) {
   opts.define("v|verbose=b",          "monitor status");
   opts.define("t|tolerance=d:0.00000001", "fractional tolerance");
   opts.define("d|decay=d:0.8",        "decay factor for simplex side");
   // declared as a double since it is read with getDouble() below
   opts.define("s|side|delta=d:10.0",  "initial simplex side width");
   opts.define("r|runs|trials=i:5000", "max num of test to do per iteration");
   opts.define("n|norm=b",             "normalize the final results");
   opts.define("nn=i:25",              "interval to which to normalize");
   opts.define("c|count=i:5",          "maximum repeat of algorithm");

   opts.define("debug=b",         "trace input parsing");
   opts.define("author=b",        "author of the program");
   opts.define("version=b",       "compilation information");
   opts.define("example=b",       "example usage");
   opts.define("h|help=b",        "short description");

   opts.process(argc, argv);

   // handle basic options:
   if (opts.getBoolean("author")) {
      cout << "Written by Craig Stuart Sapp, "
           << "craig@ccrma.stanford.edu, July 2002" << endl;
      exit(0);
   } else if (opts.getBoolean("version")) {
      cout << argv[0] << ", version: 28 July 2002" << endl;
      cout << "compiled: " << __DATE__ << endl;
      cout << MUSEINFO_VERSION << endl;
      exit(0);
   } else if (opts.getBoolean("help")) {
      usage(opts.getCommand());
      exit(0);
   } else if (opts.getBoolean("example")) {
      example();
      exit(0);
   }

   // main() reads a chord data file (arg 1) and a weight file (arg 2).
   if (opts.getArgCount() < 2) {
      usage(opts.getCommand());
      exit(1);
   }

   debugQ     = opts.getBoolean("debug");
   verboseQ   = opts.getBoolean("verbose");
   tolerance  = opts.getDouble("tolerance");
   trials     = opts.getInteger("trials");
   initdelta  = opts.getDouble("delta");
   decay      = opts.getDouble("decay");
   sidedelta  = initdelta;
   normQ      = opts.getBoolean("norm");
   norm       = opts.getInteger("nn");
   repeatcount= opts.getInteger("count");

   // The weight table has 40 entries (see getStartingWeights), so the
   // normalization interval must index into it.
   if (normQ && (norm < 0 || norm >= 40)) {
      cerr << "Error: --nn value must be in the range 0 to 39" << endl;
      exit(1);
   }
}



//////////////////////////////
//
// example -- print an example command line for the iwsimplex program.
//

void example(void) {
   cout <<
   "# Optimize the starting weights in weights.krn against the chords in\n"
   "# chords.krn, allowing 10 non-improving runs, and normalize the result:\n"
   "   iwsimplex -c 10 -n chords.krn weights.krn > optimized.krn\n"
   << endl;
}



//////////////////////////////
//
// usage -- gives the usage statement for the iwsimplex program.
//     command = program name as invoked; "iwsimplex" is used if NULL.
//

void usage(const char* command) {
   // streaming a NULL char* is undefined behavior
   const char* name = (command != NULL) ? command : "iwsimplex";
   cout <<
   "Optimizes chord-root interval weights with the downhill simplex method\n"
   "of Nelder and Mead.\n"
   "\n"
   "Usage: " << name << " [options] chord-data-file weight-file\n"
   "\n"
   "   chord-data-file  Humdrum file with **root, **count and **kern spines\n"
   "   weight-file      Humdrum file with **kern and **weight spines giving\n"
   "                    the starting weights\n"
   "\n"
   "Options:\n"
   "   -t #      fractional convergence tolerance (default 0.00000001)\n"
   "   -s #      initial simplex side width (default 10.0)\n"
   "   -d #      decay factor applied to the side width after each run\n"
   "             (default 0.8)\n"
   "   -r #      maximum function evaluations per simplex run (default 5000)\n"
   "   -c #      stop after this many consecutive runs without improvement\n"
   "             (default 5)\n"
   "   -n        normalize the final weights\n"
   "   --nn #    interval index scaled to 1.0 when normalizing (default 25)\n"
   "   -v        monitor status\n"
   "   --debug   trace input parsing\n"
   "   --example print an example command line\n"
   "   -h        print this message\n"
   << endl;
}



/////////////////////////////////////////////////////////////////////////////
//
// downhill simplex algorithm functions
//


// SWAP -- exchange two lvalues.  Requires a variable named "swap" of a
// compatible type to be in scope where it is used (amoeba() declares a
// double for this purpose).
#define  SWAP(a, b)      {swap = (a); (a) = (b); (b) = swap;}
// TINY -- added to the denominator of amoeba()'s fractional-range test so
// that it does not divide by zero when both function values are 0.
#define  TINY 1.0e-10


//////////////////////////////
//
// amoeba -- Downhill Simplex Method
//   J.A. Nelder and R. Mead,  1965, Computer Journal vol. 7, pp. 308-313.
// simplex      = N+1 vertices (rows) of a simplex in N-dimensional space
//                (N columns), which is the starting point of the search.
//                On return, row 0 holds the best vertex found.
// ftol         = fractional convergence tolerance.
// testFunction = function being minimized.
// evalCount    = returns the number of function evaluations done (the N+1
//                evaluations of the starting vertices are not counted).
// maxtests     = maximum number of function evaluations allowed in the
//                algorithm (not counting the starting vertices).
//

void amoeba(Matrix<double>& simplex, double ftol, 
      double(*testFunction)(Array<double>&), int& evalCount,
      int maxtests) {

   int N = simplex.getColumnCount();   // search-space dimensions
   int M = N + 1;                      // number of simplex vertices

   evalCount = 0;
   if (N < 1) {
      // the vertex ranking below reads functionEvaluation[1]
      return;
   }

   int i, j;
   int iworst;     // simplex vertex with the highest test function evaluation
   int inextworst; // simplex vertex with the second highest function evaluation
   int ibest;      // simplex vertex with the lowest test function evaluation
   double rtol;    // fractional range between worst and best evaluations
   double swap;    // scratch variable used by the SWAP macro
   double ysave;
   double ytry;
   Array<double> simplexSum(N);          // per-coordinate sum of all vertices
   Array<double> functionEvaluation(M);  // test function value at each vertex
   Array<double> testweights;

   for (j=0; j<M; j++) {
      simplex.getRow(testweights, j);
      functionEvaluation[j] = (*testFunction)(testweights);
   }

   for (j=0; j<N; j++) {
      simplexSum[j] = 0.0;
      for (i=0; i<M; i++) {
         simplexSum[j] += simplex.cell(i, j);
      }
   }

   for (;;) {

      if (verboseQ) {
         cout << "Eval count: " << evalCount << endl;
      }

      // Find the highest (worst), next-highest, and lowest (best) points
      // in the simplex.
      iworst = functionEvaluation[0] > functionEvaluation[1] ?
         (inextworst = 1, 0) : (inextworst = 0, 1);
      ibest = 0;
      for (i=0; i<M; i++) {
         if (functionEvaluation[i] <= functionEvaluation[ibest]) {
            ibest = i;
         }
         if (functionEvaluation[i] > functionEvaluation[iworst]) {
            inextworst = iworst;
            iworst = i;
         } else if ((functionEvaluation[i] > functionEvaluation[inextworst]) && 
               (i != iworst)) {
            inextworst = i;
         }
      }

      // Compute the fractional range from the highest to lowest point.
      rtol = 2.0*fabs(functionEvaluation[iworst] - functionEvaluation[ibest]) / 
            (fabs(functionEvaluation[iworst]) +
             fabs(functionEvaluation[ibest]) + TINY);

      // Stop when converged or when the evaluation budget is used up.  Both
      // tests are made here rather than in the loop condition so that every
      // exit moves the best vertex into row 0 before returning.
      if (rtol < ftol || evalCount >= maxtests) {
         SWAP(functionEvaluation[0], functionEvaluation[ibest]);
         for (i=0; i<N; i++) {
            SWAP(simplex.cell(0,i), simplex.cell(ibest,i));
         }
         break;
      }
      evalCount += 2;


      // Begin a new iteration.  First extrapolate by a factor -1 through
      // the face of the simplex face across from the high point
      // (reflect the simplex from the high point).

      ytry = amoebaTry(simplex, functionEvaluation, simplexSum,
            testFunction, iworst, -1.0);
      if (verboseQ) {
         cout << "Current state: " << ytry << endl;
      }

      if (ytry <= functionEvaluation[ibest]) {
         // The reflection beat the best point: try expanding by 2.
         ytry = amoebaTry(simplex, functionEvaluation, simplexSum,
            testFunction, iworst, 2.0);
      } else if (ytry >= functionEvaluation[inextworst]) {
         // The reflection is no better than the second-worst point: try a
         // one-dimensional contraction by 0.5.
         ysave = functionEvaluation[iworst];
         ytry = amoebaTry(simplex, functionEvaluation, simplexSum, 
               testFunction, iworst, 0.5);
         if (ytry >= ysave) {
            // The contraction did not help: shrink every vertex halfway
            // toward the best vertex.  simplexSum serves as scratch space
            // for each new vertex and is rebuilt afterwards.
            for (i=0; i<M; i++) {
               if (i != ibest) {
                  for (j=0; j<N; j++) {
                     simplex.cell(i,j) = 0.5 * (simplex.cell(i,j) + 
                           simplex.cell(ibest,j));
                     simplexSum[j] = simplex.cell(i,j);
                  }
                  functionEvaluation[i] = (*testFunction)(simplexSum);
               }
            }
            evalCount += N;  
            for (j=0; j<N; j++) {
               simplexSum[j] = 0.0;
               for (i = 0; i < M; i++) {
                  simplexSum[j] += simplex.cell(i,j);
               }
            }
         }
      } else { 
         // Only the reflection was evaluated; refund the second count.
         evalCount = evalCount - 1; 
      }
   }
   
}

   

//////////////////////////////
//
// amoebaTry -- move the worst vertex of the simplex through the opposite
//      face by the factor fac and keep the new point if it is better.
// simplex            = vertices (rows) of a simplex in N-dimensional space.
// functionEvaluation = the function value at each vertex of the simplex.
// simplexSum         = per-coordinate sum of all vertices; kept up to date
//                      when the worst vertex is replaced.
// testFunction       = function being minimized.
// worstPoint         = row index of the vertex to move.
// fac                = extrapolation factor relative to the opposite face
//                      (amoeba() uses -1.0, 2.0 and 0.5).
// Returns the function value at the trial point, whether or not it was kept.
//

double amoebaTry(Matrix<double>& simplex, Array<double>& functionEvaluation,
      Array<double>& simplexSum, double(*testFunction)(Array<double>&),
      int worstPoint, double fac) {

   const int dims = simplex.getColumnCount();
   const double sumScale   = (1.0 - fac) / dims;
   const double worstScale = sumScale - fac;
   int k;

   // Trial vertex: an affine combination of the vertex sum and the worst
   // vertex.
   Array<double> candidate(dims);
   for (k = 0; k < dims; k++) {
      candidate[k] = simplexSum[k] * sumScale
            - simplex.cell(worstPoint, k) * worstScale;
   }

   const double candidateValue = (*testFunction)(candidate);

   // Leave the simplex alone unless the trial point beats the worst vertex.
   if (!(candidateValue < functionEvaluation[worstPoint])) {
      return candidateValue;
   }

   functionEvaluation[worstPoint] = candidateValue;
   for (k = 0; k < dims; k++) {
      double& slot = simplex.cell(worstPoint, k);
      simplexSum[k] += candidate[k] - slot;
      slot = candidate[k];
   }
   return candidateValue;
}



//////////////////////////////
//
// testFunction -- input algorithm for the Downhill Simplex Method.
//

double testFunction(Array& weights) {
   return getErrors(weights, pitches, root, counte);
}



//////////////////////////////
//
// runDownhillSimplexMethod -- run the downhill simplex method starting from
//     initialWeights, with a starting simplex side of sidedelta.  The
//     lowest-error vertex of the final simplex is stored in returnedWeights
//     and its error (as computed by getErrors) is returned.
//

double runDownhillSimplexMethod(Array<double>& returnedWeights, 
      Array<double>& initialWeights) {

   Matrix<double> simplex;
   generateSimplex(simplex, initialWeights, sidedelta);

   double ftol      = tolerance;
   int    evalCount = 0;
   int    maxtests  = trials;

   amoeba(simplex, ftol, testFunction, evalCount, maxtests);

   // The simplex has one column per weight and one more row (vertex) than
   // columns, so every one of those N+1 rows must be examined; bounding
   // the scan by the column count alone would skip the last vertex.
   int vertexcount = simplex.getColumnCount() + 1;

   Array<double> testweights;
   simplex.getRow(testweights, 0);

   double best = getErrors(testweights, pitches, root, counte);
   double test;
   int ibest = 0;
   int i;
   for (i=1; i<vertexcount; i++) {
      simplex.getRow(testweights, i);
      test = getErrors(testweights, pitches, root, counte);
      if (test <= best) {
         ibest = i;
         best = test;
      }
   }
  
   simplex.getRow(returnedWeights, ibest);

   return best;
}



//////////////////////////////
//
// generateSimplex -- build the starting simplex for amoeba(): N+1 rows
//     (vertices) by N columns, where N is the number of weights.  Row 0 is
//     initialWeights; row v (v >= 1) is initialWeights with delta added to
//     coordinate v-1.
//

void generateSimplex(Matrix<double>& simplex, Array<double>& initialWeights, 
      double delta) {
   const int dims     = initialWeights.getSize();
   const int vertices = dims + 1;
   simplex.setSize(vertices, dims);

   int v;
   for (v = 0; v < vertices; v++) {
      simplex.setRow(v, initialWeights);
      if (v > 0) {
         // offset vertex v along axis v-1
         simplex.cell(v, v - 1) += delta;
      }
   }
}


// md5sum: 1eb977ea2d1339a3171c1b03ed039d40 iwsimplex.cpp [20050403]