package edu.uky.ai.sl;

import java.util.concurrent.Callable;
import java.util.concurrent.ExecutionException;
import java.util.concurrent.ExecutorService;
import java.util.concurrent.Executors;
import java.util.concurrent.Future;
import java.util.concurrent.TimeUnit;
import java.util.concurrent.TimeoutException;

import edu.uky.ai.SearchBudget;
import edu.uky.ai.TimeBudgetExceededException;
import edu.uky.ai.data.DataPoint;
import edu.uky.ai.data.DataSet;
import edu.uky.ai.data.LabeledDataSet;
import edu.uky.ai.data.Value;

/**
 * A learner is an algorithm which solves a supervised machine learning problem
 * by producing a {@link Model} that can classify new observations.
 * 
 * @author Stephen G. Ware
 */
public abstract class Learner {

	/** The learner's name */
	public final String name;
	
	/**
	 * Constructs a new learner with the given name.
	 * 
	 * @param name the name
	 */
	public Learner(String name) {
		this.name = name;
	}
	
	@Override
	public String toString() {
		return name;
	}
	
	/**
	 * Evaluates this learner by building a model on the given training data
	 * using the {@link Learner#learn(LabeledDataSet, SearchBudget) learn}
	 * method and then evaluating that model on the given testing data.
	 * <p>
	 * This method does not throw. Any failure during training or testing is
	 * caught and recorded as the reason in the returned {@link Result}. Such
	 * failures include exceeding a time limit, being interrupted, a model
	 * returning null for a test point, running out of memory, or any other
	 * exception or error. When both phases complete, the reason is
	 * {@link Result#SUCCESS}.
	 * 
	 * @param training the data on which to build the model
	 * @param trainingTime the maximum number of milliseconds allowed to build
	 * the model, or {@link SearchBudget#INFINITE_TIME} for no limit
	 * @param testing the data on which to test the model's accuracy
	 * @param classificationTime the maximum number of milliseconds the model
	 * may take to evaluate a specific test data point, or
	 * {@link SearchBudget#INFINITE_TIME} for no limit
	 * @return a {@link Result} object summarizing the evaluation
	 */
	public final Result evaluate(LabeledDataSet training, long trainingTime, LabeledDataSet testing, long classificationTime) {
		Model model = null;
		long train = 0;
		int correct = 0;
		long test = 0;
		String reason = Result.SUCCESS;
		// Small heap reserve that is dropped when an OutOfMemoryError is
		// caught below, giving the recovery path a little headroom. The store
		// into extra[0] presumably just keeps the array from looking unused.
		Object[] extra = new Object[512];
		extra[0] = null;
		try {
			long start = System.currentTimeMillis();
			model = train(training, trainingTime);
			train = System.currentTimeMillis() - start;
			// If learn() returned null, testing is skipped and 0 correct is reported.
			if(model != null) {
				start = System.currentTimeMillis();
				correct = test(model, testing, classificationTime);
				test = System.currentTimeMillis() - start;
			}
		}
		catch(OutOfMemoryError r) {
			extra = null;
			System.gc();
			reason = "out of memory";
		}
		catch(Throwable e) {
			e.printStackTrace();
			reason = e.getMessage();
			if(reason == null || reason.isEmpty())
				reason = e.toString();
		}
		return new Result(this, training, model, train, testing, correct, test, reason);
	}
	
	/**
	 * Trains a model on the given training data. The
	 * {@link #learn(LabeledDataSet, SearchBudget) learn} method runs on a
	 * separate worker thread so that the time limit can be enforced.
	 * 
	 * @param data the data on which to train the model
	 * @param trainingTime the maximum milliseconds training may take, or
	 * {@link SearchBudget#INFINITE_TIME} for no limit
	 * @return a trained model
	 * @throws InterruptedException if training is interrupted
	 * @throws TimeoutException if the training time limit is exceeded. This
	 * covers both the limit elapsing here and the learner throwing
	 * {@link TimeBudgetExceededException}.
	 * @throws Throwable anything else the learner throws, unwrapped
	 */
	private final Model train(LabeledDataSet data, long trainingTime) throws Throwable {
		return runWithTimeout(
			() -> learn(data, new SearchBudget(SearchBudget.INFINITE_OPERATIONS, trainingTime)),
			trainingTime,
			"training"
		);
	}
	
	/**
	 * Tests a model on the given test data.
	 * 
	 * @param model the model to test
	 * @param data the data on which to test the model
	 * @param classificationTime the maximum number of milliseconds the model
	 * may take to classify any given test data point
	 * @return the number of test data points classified correctly
	 * @throws IllegalStateException if the model returns null when classifying
	 * any test data point
	 * @throws TimeoutException if classifying any single point exceeds the
	 * time limit
	 * @throws Throwable anything else the model throws while classifying
	 */
	private final int test(Model model, LabeledDataSet data, long classificationTime) throws Throwable {
		// The model is only shown the unlabeled copy of each point. Labels are
		// read from the original data set at the same index.
		DataSet unlabeled = data.removeLabels();
		int correct = 0;
		for(int i=0; i<data.points.size(); i++) {
			Value classification = classify(model, unlabeled.points.get(i), classificationTime);
			if(classification == null)
				throw new IllegalStateException("Model returned null.");
			else if(data.points.get(i).label.equals(classification))
				correct++;
		}
		return correct;
	}
	
	/**
	 * Uses the given model to classify a single given data point. The
	 * classification runs on a separate worker thread so that the time limit
	 * can be enforced.
	 * 
	 * @param model the model used to classify
	 * @param point the data point to classify
	 * @param classificationTime the maximum number of milliseconds the model
	 * may take to classify the point, or {@link SearchBudget#INFINITE_TIME}
	 * for no limit
	 * @return the model's classification
	 * @throws InterruptedException if classification is interrupted
	 * @throws TimeoutException if the classification time limit is exceeded.
	 * This covers both the limit elapsing here and the model throwing
	 * {@link TimeBudgetExceededException}.
	 * @throws Throwable anything else the model throws, unwrapped
	 */
	private final Value classify(Model model, DataPoint point, long classificationTime) throws Throwable {
		return runWithTimeout(
			() -> model.classify(point, new SearchBudget(SearchBudget.INFINITE_OPERATIONS, classificationTime)),
			classificationTime,
			"classification"
		);
	}
	
	/**
	 * Runs a task on a dedicated worker thread and waits at most the given
	 * number of milliseconds for its result.
	 * <p>
	 * The worker is always shut down before this method returns. If the task
	 * does not finish in time, or the waiting thread is interrupted, the
	 * worker is interrupted. A task that responds to interrupts therefore
	 * stops instead of consuming CPU in the background. The worker is a daemon
	 * thread, so a task that ignores interrupts cannot keep the JVM alive.
	 * 
	 * @param <T> the type of the task's result
	 * @param task the task to run
	 * @param timeLimit the maximum milliseconds to wait, or
	 * {@link SearchBudget#INFINITE_TIME} to wait indefinitely
	 * @param activity a short description of the task (e.g. "training"), used
	 * in exception messages and as the worker thread's name
	 * @return the task's result
	 * @throws InterruptedException if the calling thread is interrupted while
	 * waiting. The thread's interrupt status is restored before throwing.
	 * @throws TimeoutException if the time limit elapses or the task throws
	 * {@link TimeBudgetExceededException}
	 * @throws Throwable any other exception or error thrown by the task
	 */
	private static <T> T runWithTimeout(Callable<T> task, long timeLimit, String activity) throws Throwable {
		ExecutorService executor = Executors.newSingleThreadExecutor(runnable -> {
			Thread worker = new Thread(runnable, activity);
			worker.setDaemon(true);
			return worker;
		});
		try {
			Future<T> future = executor.submit(task);
			long wait = timeLimit == SearchBudget.INFINITE_TIME ? Long.MAX_VALUE : timeLimit;
			return future.get(wait, TimeUnit.MILLISECONDS);
		}
		catch(InterruptedException e) {
			// evaluate() swallows every Throwable, so restore the interrupt
			// status instead of letting the interrupt be lost.
			Thread.currentThread().interrupt();
			throw new InterruptedException(activity + " interrupted");
		}
		catch(TimeoutException e) {
			throw new TimeoutException(activity + " time exceeded");
		}
		catch(ExecutionException e) {
			// Future.get() wraps everything the task throws. A budget overrun
			// reported by the task itself therefore only shows up here.
			Throwable cause = e.getCause();
			if(cause instanceof TimeBudgetExceededException)
				throw new TimeoutException(activity + " time exceeded");
			throw cause == null ? e : cause;
		}
		finally {
			// After a timeout or interruption this interrupts the still-running
			// worker. After normal completion it just stops the idle thread.
			executor.shutdownNow();
		}
	}
	
	/**
	 * Given {@link edu.uky.ai.data.LabeledDataSet labeled training data set},
	 * this method returns a learned model that can be used to classify new
	 * {@link edu.uky.ai.data.DataPoint data points}.
	 * <p>
	 * When called from {@link #evaluate(LabeledDataSet, long, LabeledDataSet, long)
	 * evaluate}, the budget's time limit is the training time limit. Throwing
	 * {@link TimeBudgetExceededException} is reported as
	 * "training time exceeded".
	 * 
	 * @param data the training data
	 * @param budget the time budget allowed for training
	 * @return the learned model
	 */
	public abstract Model learn(LabeledDataSet data, SearchBudget budget);
}