package edu.uky.ai.rl;

import java.awt.GraphicsEnvironment;
import java.io.BufferedOutputStream;
import java.io.File;
import java.io.FileOutputStream;
import java.io.IOException;
import java.io.ObjectOutputStream;
import java.io.OutputStream;
import java.io.OutputStreamWriter;
import java.io.Writer;
import java.nio.charset.StandardCharsets;
import java.text.DecimalFormat;
import java.util.ArrayList;
import java.util.Comparator;
import java.util.function.Function;

import edu.uky.ai.SearchBudget;
import edu.uky.ai.io.DummyOutputStream;
import edu.uky.ai.util.Arguments;
import edu.uky.ai.util.Table;
import edu.uky.ai.util.Utilities;

/**
 * The entry point for the reinforcement learning application.
 * <p>
 * When exactly one learner and one stochastic process are given, the learner
 * is run once and its policy is serialized to the output. Otherwise every
 * learner is run on every process and an HTML report is written.
 * 
 * @author Stephen G. Ware
 */
public class Main {
	
	private static final String USAGE = "Usage: java -jar rl.jar -a <learners> -p <problems> [-lnl <nodes>] [-ltl <millis>] [-enl <nodes>] [-etl <millis>] [-d <millis>] [-o <file>]" +
		"\n  <learners>    is one or more JAR files containing an instance of " + Learner.class.getName() +
		"\n  <problems>    is one or more stochastic process files" +
		"\n  -lnl <nodes>  is the optional maximum number of state transitions each learner may perform while learning a policy" +
		"\n  -ltl <millis> is the optional maximum number of milliseconds a learner may take while learning a policy" +
		"\n  -enl <nodes>  is the optional maximum number of state transitions each learner may perform while evaluating a policy" +
		"\n  -etl <millis> is the optional maximum number of milliseconds a learner may take while evaluating a policy" +
		"\n  -d <millis>   is the delay between frames of the GUI (defaults to " + Settings.DEFAULT_GUI_DELAY + ")" +
		"\n  <file>        is the file to which output will be written in HTML format";
	
	/**
	 * Launches the reinforcement learning application according to its command
	 * line arguments.
	 * <p>
	 * Prints the usage message and exits with status 1 when fewer than four
	 * arguments are given (the minimum needed for {@code -a x -p y}).
	 * 
	 * @param args the command line arguments
	 * @throws Exception if any uncaught exceptions are thrown
	 */
	public static void main(String[] args) throws Exception {
		if(args.length < 4) {
			System.out.println(USAGE);
			System.exit(1);
		}
		Arguments arguments = new Arguments(args);
		// Load every learner JAR and parse every process file named on the command line.
		ArrayList<Learner> learners = new ArrayList<>();
		for(String url : arguments.getValues("-a"))
			learners.add(Utilities.loadFromJARFile(Learner.class, new File(url)));
		ArrayList<StochasticProcess> processes = new ArrayList<>();
		ProcessParser parser = new ProcessParser();
		for(String url : arguments.getValues("-p"))
			processes.add(parser.parse(new File(url)));
		// Every limit is optional and defaults to "unlimited".
		int learningNodeLimit = SearchBudget.INFINITE_OPERATIONS;
		if(arguments.containsKey("-lnl"))
			learningNodeLimit = Integer.parseInt(arguments.getValue("-lnl"));
		long learningTimeLimit = SearchBudget.INFINITE_TIME;
		if(arguments.containsKey("-ltl"))
			learningTimeLimit = Long.parseLong(arguments.getValue("-ltl"));
		int evaluationNodeLimit = SearchBudget.INFINITE_OPERATIONS;
		if(arguments.containsKey("-enl"))
			evaluationNodeLimit = Integer.parseInt(arguments.getValue("-enl"));
		long evaluationTimeLimit = SearchBudget.INFINITE_TIME;
		if(arguments.containsKey("-etl"))
			evaluationTimeLimit = Long.parseLong(arguments.getValue("-etl"));
		int delay = Settings.DEFAULT_GUI_DELAY;
		if(arguments.containsKey("-d"))
			delay = Integer.parseInt(arguments.getValue("-d"));
		// Without -o the output goes to a DummyOutputStream (presumably discarded; see that class).
		final OutputStream output;
		if(arguments.containsKey("-o"))
			output = new BufferedOutputStream(new FileOutputStream(arguments.getValue("-o")));
		else
			output = new DummyOutputStream();
		try {
			// One learner on one process: learn a single policy; anything else: run a full benchmark.
			if(learners.size() == 1 && processes.size() == 1)
				learn(learners.get(0), processes.get(0), learningNodeLimit, learningTimeLimit, evaluationNodeLimit, evaluationTimeLimit, delay, output);
			else
				benchmark(learners.toArray(new Learner[learners.size()]), processes.toArray(new StochasticProcess[processes.size()]), learningNodeLimit, learningTimeLimit, evaluationNodeLimit, evaluationTimeLimit, delay, output);
		}
		finally {
			// Always flush and release the output, even when learning fails.
			output.close();
		}
	}
	
	/**
	 * Uses a given learner to find a policy for a given stochastic process.
	 * A summary of the result is printed to standard output. If learning
	 * succeeded, the policy is serialized to {@code output}; note that this
	 * closes {@code output}, because the wrapping {@link ObjectOutputStream}
	 * is closed once the policy has been written. Problems encountered while
	 * serializing are reported via a stack trace rather than rethrown.
	 * 
	 * @param learner the learner to use
	 * @param process the stochastic process for which a policy should be found
	 * @param learningNodeLimit the maximum number of calls to {@link StochasticProcess#transition(State, Action)} that can be made while learning the policy
	 * @param learningTimeLimit the maximum number of milliseconds that can be spent learning the policy
	 * @param evaluationNodeLimit the maximum number of calls to {@link StochasticProcess#transition(State, Action)} that can be made while evaluating the policy
	 * @param evaluationTimeLimit the maximum number of milliseconds that can be spent evaluating the policy
	 * @param delay the delay between frame of the GUI when visualizing the learning and evaluation processes (when set to 0, no GUI will be shown)
	 * @param output where the learned policy will be serialized (if one is found)
	 * @throws IOException if a problem occurs while writing the output
	 */
	public static void learn(Learner learner, StochasticProcess process, int learningNodeLimit, long learningTimeLimit, int evaluationNodeLimit, long evaluationTimeLimit, int delay, OutputStream output) throws IOException {
		// Only show the GUI when a display is available and a non-zero delay was requested.
		ProcessFrame frame = null;
		if(!GraphicsEnvironment.isHeadless() && delay != 0)
			frame = new ProcessFrame(delay);
		Result result = learner.learn(process, learningNodeLimit, learningTimeLimit, evaluationNodeLimit, evaluationTimeLimit, frame);
		System.out.println(
			"Learner:     " + result.learner.name + "\n" +
			"Task:        " + result.process.name + "\n" +
			"Result:      " + result.reason + "\n" +
			"Transitions: " + result.transitions + "\n" +
			"Time (ms):   " + result.time + "\n" +
			"Score:       " + result.score
		);
		if(result.success) {
			try(ObjectOutputStream out = new ObjectOutputStream(output)) {
				out.writeObject(result.policy);
				System.out.println("Policy serialized to output.");
			}
			catch(Exception ex) {
				// Best effort: a failure to serialize is reported but does not abort the run.
				ex.printStackTrace();
			}
		}
	}
	
	/**
	 * Compares the performance of one or more learners on one or more
	 * stochastic processes. Progress and the table of average scores are
	 * printed to standard output, and a UTF-8 encoded HTML report (average
	 * score, total score, transitions used, and time spent) is written to
	 * {@code output}. The output is flushed but not closed.
	 * 
	 * @param learners the learners to compare
	 * @param processes the stochastic processes on which the learners will be compared
	 * @param learningNodeLimit the maximum number of calls to {@link StochasticProcess#transition(State, Action)} that can be made while learning a policy
	 * @param learningTimeLimit the maximum number of milliseconds that can be spent learning a policy
	 * @param evaluationNodeLimit the maximum number of calls to {@link StochasticProcess#transition(State, Action)} that can be made while evaluating a policy
	 * @param evaluationTimeLimit the maximum number of milliseconds that can be spent evaluating a policy
	 * @param delay the delay between frame of the GUI when visualizing the learning and evaluation processes (when set to 0, no GUI will be shown)
	 * @param output where the output will be written in HTML format
	 * @throws IOException if a problem occurs while writing the output
	 */
	public static void benchmark(Learner[] learners, StochasticProcess[] processes, int learningNodeLimit, long learningTimeLimit, int evaluationNodeLimit, long evaluationTimeLimit, int delay, OutputStream output) throws IOException {
		// One row per process, one column per learner.
		Table results = new Table(processes, learners);
		// Only show the GUI when a display is available and a non-zero delay was requested.
		ProcessFrame frame = null;
		if(!GraphicsEnvironment.isHeadless() && delay != 0)
			frame = new ProcessFrame(delay);
		for(StochasticProcess process : processes) {
			for(Learner learner : learners) {
				System.out.print("Agent \"" + learner + "\" learning task \"" + process + "\": ");
				Result result = learner.learn(process, learningNodeLimit, learningTimeLimit, evaluationNodeLimit, evaluationTimeLimit, frame);
				if(result.success)
					System.out.println("score " + result.score);
				else
					System.out.println(result.reason);
				results.getCell(process, learner).value = result;
			}
		}
		results = results.sortByColumn(BEST_LEARNERS);
		Table average = results.transform(SCORE).addAverageColumn().addAverageRow().transform(TWO_DECIMAL_PLACES);
		System.out.println("Results:\n" + average);
		// Use an explicit charset (declared in the document head) so the report
		// does not depend on the platform default encoding.
		Writer writer = new OutputStreamWriter(output, StandardCharsets.UTF_8);
		writer.append("<html>\n<head>\n<meta charset=\"utf-8\">\n<title>Reinforcement Learning Benchmark Results</title>");
		writer.append("\n<style>\ntable { border-collapse: collapse; }\ntable, tr, th, td { border: 1px solid black; }\ntr:nth-child(odd) { background-color: lightgray; }\nth { font-weight: bold; }\ntd { text-align: right; }\n</style>");
		writer.append("\n</head>\n<body>\n\n<h1>Reinforcement Learning Benchmark Results</h1>");
		writer.append("\n\n<h2>Average Score</h2>\n" + average.toHTML());
		writer.append("\n\n<h2>Total Score</h2>\n" + results.transform(SCORE).addTotalColumn().addTotalRow().transform(TWO_DECIMAL_PLACES).toHTML());
		writer.append("\n\n<h2>Transitions Used During Learning</h2>\n" + results.transform(TRANSITIONS).addTotalColumn().addTotalRow().transform(TWO_DECIMAL_PLACES).toHTML());
		writer.append("\n\n<h2>Time Spent Learning</h2>\n" + results.transform(TIME).addTotalColumn().addTotalRow().transform(TWO_DECIMAL_PLACES).toHTML());
		writer.append("\n\n</body>\n</html>");
		// Flush only; the caller owns the stream and is responsible for closing it.
		writer.flush();
	}
	
	// NOTE: the extractors below are declared before BEST_LEARNERS on purpose.
	// A lambda in a static initializer may not refer by simple name to a
	// static field declared later in the same class.
	
	/** Maps a {@link Result} to its score; any other object is returned unchanged. */
	private static final Function<Object, Object> SCORE = object -> {
		if(object instanceof Result)
			return ((Result) object).score;
		else
			return object;
	};
	
	/** Maps a {@link Result} to its transition count; any other object is returned unchanged. */
	private static final Function<Object, Object> TRANSITIONS = object -> {
		if(object instanceof Result)
			return ((Result) object).transitions;
		else
			return object;
	};
	
	/** Maps a {@link Result} to its time; any other object is returned unchanged. */
	private static final Function<Object, Object> TIME = object -> {
		if(object instanceof Result)
			return ((Result) object).time;
		else
			return object;
	};
	
	/**
	 * Orders table columns (learners) from best to worst. Assuming
	 * {@code Utilities.subtract(a, b)} yields {@code a - b}, a column sorts
	 * earlier when it has a higher average score; ties are broken by higher
	 * total score, then by fewer total transitions, then by less total time.
	 */
	private static final Comparator<Table.Column> BEST_LEARNERS = (column1, column2) -> {
		Number difference = Utilities.subtract(column2.average(SCORE), column1.average(SCORE));
		if(difference.doubleValue() == 0d)
			difference = Utilities.subtract(column2.sum(SCORE), column1.sum(SCORE));
		// For cost measures the operands are swapped: smaller is better.
		if(difference.doubleValue() == 0d)
			difference = Utilities.subtract(column1.sum(TRANSITIONS), column2.sum(TRANSITIONS));
		if(difference.doubleValue() == 0d)
			difference = Utilities.subtract(column1.sum(TIME), column2.sum(TIME));
		if(difference.doubleValue() == 0d)
			return 0;
		else if(difference.doubleValue() < 0d)
			return -1;
		else
			return 1;
	};
	
	/** Formats numbers with exactly two decimal places; any other object is converted via {@code toString()}. */
	private static final Function<Object, String> TWO_DECIMAL_PLACES = object -> {
		if(object instanceof Number)
			return new DecimalFormat("0.00").format((Number) object);
		else
			return object.toString();
	};
}
