package edu.uky.ai.sl;

import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.util.ArrayList;
import java.util.Collection;
import java.util.Comparator;
import java.util.Locale;
import java.util.function.Function;

import edu.uky.ai.SearchBudget;
import edu.uky.ai.data.LabeledDataPoint;
import edu.uky.ai.data.LabeledDataSet;
import edu.uky.ai.io.DummyOutputStream;
import edu.uky.ai.util.Arguments;
import edu.uky.ai.util.Table;
import edu.uky.ai.util.Utilities;

/**
 * The entry point for the supervised learning application.
 * 
 * @author Stephen G. Ware
 */
public class Main {
	
	/** The number of folds used for cross-validation */
	private static final int FOLDS = 10;
	
	private static final String USAGE = "Usage: java -jar sl.jar -a <learners> -p <data> [-ltl <millis>] [-ctl <millis>] [-o <file>]" +
		"\n  <learners>    is one or more JAR files containing an instance of " + Learner.class.getName() +
		"\n  <data>        is one or more data files in CSV format" +
		"\n  -ltl <millis> is the optional maximum number of milliseconds a learner may take to build a model" +
		"\n  -ctl <millis> is the optional maximum number of milliseconds a model may take to classify a data point" +
		"\n  -o <file>     is the optional file to which output will be written (a serialized model when one" +
		"\n                learner and one data set are given, otherwise benchmark results in HTML format)";

	/**
	 * Launches the supervised learning application according to its command
	 * line arguments. With exactly one learner and one data set, the learner
	 * is evaluated and its model is serialized to the output file. Otherwise,
	 * all learners are benchmarked on all data sets and an HTML report is
	 * written to the output file.
	 * 
	 * @param args the command line arguments
	 * @throws Exception if any uncaught exceptions are thrown
	 */
	public static void main(String[] args) throws Exception {
		if(args.length < 2) {
			System.out.println(USAGE);
			System.exit(1);
		}
		Arguments arguments = new Arguments(args);
		ArrayList<Learner> learners = new ArrayList<>();
		for(String url : arguments.getValues("-a")) {
			System.out.println("Loading agent \"" + url + "\"...");
			learners.add(Utilities.loadFromJARFile(Learner.class, new File(url)));
		}
		ArrayList<LabeledDataSet> data = new ArrayList<>();
		for(String url : arguments.getValues("-p")) {
			System.out.println("Reading data set \"" + url + "\"...");
			data.add(LabeledDataSet.read(new File(url)));
		}
		// At least one learner and one data set are required; without this
		// check an empty benchmark would be run silently.
		if(learners.isEmpty() || data.isEmpty()) {
			System.out.println(USAGE);
			System.exit(1);
		}
		long learningTimeLimit = parseTimeLimit(arguments, "-ltl");
		long classificationTimeLimit = parseTimeLimit(arguments, "-ctl");
		try(OutputStream output = openOutput(arguments)) {
			if(learners.size() == 1 && data.size() == 1)
				learn(learners.get(0), data.get(0), learningTimeLimit, classificationTimeLimit, output);
			else
				benchmark(learners.toArray(new Learner[0]), data.toArray(new LabeledDataSet[0]), learningTimeLimit, classificationTimeLimit, output);
		}
	}
	
	/**
	 * Reads an optional time limit (in milliseconds) from the command line.
	 * 
	 * @param arguments the parsed command line arguments
	 * @param key the flag holding the limit (e.g. {@code -ltl})
	 * @return the parsed limit, or {@link SearchBudget#INFINITE_TIME} if the
	 * flag was not given
	 * @throws NumberFormatException if the value is not a valid long
	 */
	private static long parseTimeLimit(Arguments arguments, String key) {
		if(arguments.containsKey(key))
			return Long.parseLong(arguments.getValue(key));
		else
			return SearchBudget.INFINITE_TIME;
	}
	
	/**
	 * Opens the output stream named by the {@code -o} flag, or a
	 * {@link DummyOutputStream} if no output file was given.
	 * 
	 * @param arguments the parsed command line arguments
	 * @return the stream to which results will be written; the caller must
	 * close it
	 * @throws IOException if the output file cannot be opened
	 */
	private static OutputStream openOutput(Arguments arguments) throws IOException {
		if(arguments.containsKey("-o"))
			return new BufferedOutputStream(new FileOutputStream(arguments.getValue("-o")));
		else
			return new DummyOutputStream();
	}

	/**
	 * One of several folds of the data set used to train and then evaluate
	 * a {@link Learner}.
	 * 
	 * @author Stephen G. Ware
	 */
	private static final class Fold {
		
		/** The original data set */
		public final LabeledDataSet data;
		
		/** The fold number */
		public final int index;
		
		/** The portion of the data used for training */
		public final LabeledDataSet training;
		
		/** The portion of the data used for testing */
		public final LabeledDataSet testing;
		
		/**
		 * Constructs a new fold with the given parameters.
		 * 
		 * @param data the original data set
		 * @param index the fold number
		 * @param training the data used for training
		 * @param testing the data used for testing
		 */
		public Fold(LabeledDataSet data, int index, Collection<LabeledDataPoint> training, Collection<LabeledDataPoint> testing) {
			this.data = data;
			this.index = index;
			this.training = new LabeledDataSet("training", data.features, data.labels, training);
			this.testing = new LabeledDataSet("testing", data.features, data.labels, testing);
		}
	}
	
	/**
	 * Given a collection of data sets, this method divides each one into 10
	 * folds for 10-fold cross-validation.
	 * 
	 * @param data the collection of data sets
	 * @return an array of arrays, each of length 10, representing its folds
	 */
	static final Fold[][] fold(Collection<LabeledDataSet> data) {
		Fold[][] sets = new Fold[data.size()][];
		int index = 0;
		for(LabeledDataSet set : data)
			sets[index++] = fold(set);
		return sets;
	}
	
	/**
	 * Given a data set, this method folds it 10 ways. Point {@code i} is
	 * placed in the testing set of fold {@code i % 10} and in the training
	 * set of every other fold.
	 * 
	 * @param data the data set
	 * @return an array of 10 folds
	 */
	static final Fold[] fold(LabeledDataSet data) {
		Fold[] folds = new Fold[FOLDS];
		for(int fold=0; fold<FOLDS; fold++) {
			// Fresh lists per fold, so that correctness does not depend on
			// whether the LabeledDataSet constructor copies its input.
			ArrayList<LabeledDataPoint> training = new ArrayList<>();
			ArrayList<LabeledDataPoint> testing = new ArrayList<>();
			for(int i=0; i<data.points.size(); i++) {
				if(i % FOLDS == fold)
					testing.add(data.points.get(i));
				else
					training.add(data.points.get(i));
			}
			folds[fold] = new Fold(data, fold, training, testing);
		}
		return folds;
	}
	
	/**
	 * Evaluates a given {@link Learner} on a
	 * {@link edu.uky.ai.data.LabeledDataSet labeled data set} using 10-fold
	 * cross-validation, prints a summary to standard output, and, if at
	 * least one test point was classified correctly, serializes the
	 * summary's model to the output stream.
	 * 
	 * @param learner the learner to evaluate
	 * @param data the data on which to train and evaluate the learner
	 * @param learningTimeLimit the maximum number of milliseconds the learner
	 * may spend learning a {@link Model}
	 * @param classificationTimeLimit the maximum number of milliseconds a
	 * {@link Model} may take to classify an unknown data point
	 * @param output the output stream to which the model will be serialized;
	 * it is flushed but not closed (the caller owns it)
	 * @throws IOException if an error occurs while writing to the output
	 * stream
	 */
	public static void learn(Learner learner, LabeledDataSet data, long learningTimeLimit, long classificationTimeLimit, OutputStream output) throws IOException {
		Summary summary = evaluate(learner, fold(data), learningTimeLimit, classificationTimeLimit);
		System.out.println(
			"Learner:   " + summary.learner.name + "\n" +
			"Data:      " + data.name + "\n" +
			"Correct:   " + summary.correct + "\n" +
			"Accuracy:  " + Utilities.percent(summary.accuracy) + "\n" +
			"Time (ms): " + summary.time
		);
		if(summary.correct > 0) {
			try {
				// Flush instead of close: closing the ObjectOutputStream would
				// also close the caller's stream.
				ObjectOutputStream out = new ObjectOutputStream(output);
				out.writeObject(summary.model);
				out.flush();
				System.out.println("Best model serialized to output.");
			}
			catch(Exception ex) {
				// Best effort: a model that cannot be serialized should not
				// abort the run after its results have been reported.
				ex.printStackTrace();
			}
		}
	}
	
	/**
	 * Compares many {@link Learner}s on many
	 * {@link edu.uky.ai.data.LabeledDataSet labeled data sets} using 10-fold
	 * cross-validation, prints an accuracy table to standard output, and
	 * writes an HTML report (UTF-8) to the output stream.
	 * 
	 * @param learners the learners to evaluate
	 * @param data the data sets on which to train and evaluate the learners
	 * @param learningTimeLimit the maximum number of milliseconds the learner
	 * may spend learning a {@link Model}
	 * @param classificationTimeLimit the maximum number of milliseconds a
	 * {@link Model} may take to classify an unknown data point
	 * @param output the output stream to which the HTML report will be
	 * written; it is flushed but not closed
	 * @throws IOException if an error occurs while writing to the output
	 * stream
	 */
	public static void benchmark(Learner[] learners, LabeledDataSet[] data, long learningTimeLimit, long classificationTimeLimit, OutputStream output) throws IOException {
		Fold[][] sets = new Fold[data.length][];
		for(int i=0; i<data.length; i++) {
			System.out.println("Folding dataset \"" + data[i].name + "\"...");
			sets[i] = fold(data[i]);
		}
		Table results = new Table(data, learners);
		for(int i=0; i<data.length; i++)
			for(Learner learner : learners)
				results.getCell(data[i], learner).value = evaluate(learner, sets[i], learningTimeLimit, classificationTimeLimit);
		results = results.sortByColumn(BEST_LEARNERS);
		Table average = results.transform(ACCURACY).addAverageColumn().addAverageRow().transform(TWO_DECIMAL_PLACES);
		System.out.println("Results:\n" + average);
		// Encode explicitly as UTF-8 (and declare it in the document) so the
		// report does not depend on the platform's default charset.
		Writer writer = new OutputStreamWriter(output, StandardCharsets.UTF_8);
		writer.append("<html>\n<head>\n<meta charset=\"utf-8\">\n<title>Supervised Learning Benchmark Results</title>");
		writer.append("\n<style>\ntable { border-collapse: collapse; }\ntable, tr, th, td { border: 1px solid black; }\ntr:nth-child(odd) { background-color: lightgray; }\nth { font-weight: bold; }\ntd { text-align: right; }\n</style>");
		writer.append("\n</head>\n<body>\n\n<h1>Supervised Learning Benchmark Results</h1>");
		writer.append("\n\n<h2>Average Accuracy</h2>\n" + average.toHTML());
		writer.append("\n\n<h2>Total Correct Classifications</h2>\n" + results.transform(CORRECT).addTotalColumn().addTotalRow().transform(CORRECT).toHTML());
		writer.append("\n\n<h2>Total Time Spent Learning and Classifying</h2>\n" + results.transform(TIME).addTotalColumn().addTotalRow().transform(TIME).toHTML());
		writer.append("\n\n</body>\n</html>\n");
		writer.flush();
	}
	
	/**
	 * Evaluates a single learner on a single, 10-folded labeled data set.
	 * 
	 * @param learner the learner
	 * @param folds the 10 folds of the data (must not be empty)
	 * @param learningTimeLimit the maximum number of milliseconds the learner
	 * may spend learning a {@link Model}
	 * @param classificationTimeLimit the maximum number of milliseconds a
	 * {@link Model} may take to classify an unknown data point
	 * @return a summary of the results
	 */
	private static final Summary evaluate(Learner learner, Fold[] folds, long learningTimeLimit, long classificationTimeLimit) {
		System.out.println("Evaluating agent \"" + learner.name + "\" on dataset \"" + folds[0].data.name + "\":");
		Result[] results = new Result[folds.length];
		for(int i=0; i<results.length; i++) {
			System.out.print("  fold " + (folds[i].index + 1) + ": ");
			results[i] = learner.evaluate(folds[i].training, learningTimeLimit, folds[i].testing, classificationTimeLimit);
			if(results[i].reason.equals(Result.SUCCESS))
				System.out.println(Utilities.percent(results[i].accuracy));
			else
				System.out.println(results[i].reason);
		}
		return new Summary(learner, folds[0].data, results);
	}
	
	/**
	 * Builds a table-cell transformer that replaces a {@link Summary} with
	 * one of its fields and passes every other value through unchanged.
	 * 
	 * @param field extracts the desired value from a summary
	 * @return the transformer
	 */
	private static Function<Object, Object> summaryField(Function<Summary, Object> field) {
		return object -> object instanceof Summary ? field.apply((Summary) object) : object;
	}
	
	/** Replaces a {@link Summary} cell with its accuracy */
	private static final Function<Object, Object> ACCURACY = summaryField(summary -> summary.accuracy);
	
	/** Replaces a {@link Summary} cell with its number of correct classifications */
	private static final Function<Object, Object> CORRECT = summaryField(summary -> summary.correct);
	
	/** Replaces a {@link Summary} cell with its time */
	private static final Function<Object, Object> TIME = summaryField(summary -> summary.time);
	
	/**
	 * Renders a {@link Double} fraction as a percentage with two decimal
	 * places (e.g. 0.5 becomes "50.00%") using a locale-independent decimal
	 * point; other values are rendered with {@link String#valueOf(Object)}.
	 */
	private static final Function<Object, String> TWO_DECIMAL_PLACES = object -> {
		if(object instanceof Double)
			return String.format(Locale.ROOT, "%.2f%%", ((Double) object) * 100);
		else
			return String.valueOf(object);
	};
	
	/**
	 * Ranks learner columns: a negative result places {@code column1} first.
	 * Columns are ordered by higher average accuracy, then by more total
	 * correct classifications, then by less total time.
	 */
	private static final Comparator<Table.Column> BEST_LEARNERS = (column1, column2) -> {
		Number difference = Utilities.subtract(column2.average(ACCURACY), column1.average(ACCURACY));
		if(difference.doubleValue() == 0d)
			difference = Utilities.subtract(column2.sum(CORRECT), column1.sum(CORRECT));
		if(difference.doubleValue() == 0d)
			difference = Utilities.subtract(column1.sum(TIME), column2.sum(TIME));
		if(difference.doubleValue() == 0d)
			return 0;
		else if(difference.doubleValue() < 0d)
			return -1;
		else
			return 1;
	};
}