package edu.uky.ai.util;

import java.io.StringWriter;
import java.util.ArrayList;
import java.util.Collections;
import java.util.Comparator;
import java.util.HashMap;
import java.util.Iterator;
import java.util.function.BiConsumer;
import java.util.function.Function;
import java.util.function.Predicate;

/**
 * Logic for displaying simple 2-dimensional tables of data.
 * 
 * @author Stephen G. Ware
 */
public class Table implements Cloneable {
	
	/**
	 * The common superclass of {@link Row} and {@link Column}: a numbered,
	 * labeled group of cells that all belong to the same table.
	 * 
	 * @author Stephen G. Ware
	 */
	public static abstract class Sequence implements Iterable<Cell> {
		
		/** The table that owns this sequence */
		public final Table table;
		
		/** The index of this row or column within its table */
		public final int number;
		
		/** The label object that identifies this row or column */
		public final Object label;
		
		/** The cells contained in this row or column */
		public final ImmutableArray<Cell> cells;
		
		private Sequence(Table table, int number, Object label, Cell[] cells) {
			this.cells = new ImmutableArray<>(cells);
			this.label = label;
			this.number = number;
			this.table = table;
		}
		
		/**
		 * Uses the row or column number as the hash code. Note that
		 * {@code equals} is not overridden, so sequences are still compared by
		 * identity.
		 */
		@Override
		public int hashCode() {
			return this.number;
		}

		/**
		 * Iterates over the cells of this sequence.
		 */
		@Override
		public Iterator<Cell> iterator() {
			return this.cells.iterator();
		}
		
		/**
		 * Adds up the numeric values in this sequence after converting each
		 * cell value with the given function. Converted values which are not
		 * {@link java.lang.Number}s are skipped.
		 * 
		 * @param function applied to every cell value before it is considered
		 * for the sum
		 * @return the sum of the converted cell values
		 */
		public Number sum(Function<Object, ?> function) {
			return Table.sum(this.cells, function);
		}
		
		/**
		 * Adds up the numeric values in this sequence, skipping any cell
		 * value which is not a {@link java.lang.Number}.
		 * 
		 * @return the sum of the cell values
		 */
		public Number sum() {
			return sum(Function.identity());
		}
		
		/**
		 * Averages the numeric values in this sequence after converting each
		 * cell value with the given function. Converted values which are not
		 * {@link java.lang.Number}s are skipped and do not count toward the
		 * number of values averaged.
		 * 
		 * @param function applied to every cell value before it is considered
		 * for the average
		 * @return the average of the converted cell values
		 */
		public Number average(Function<Object, ?> function) {
			return Table.average(this.cells, function);
		}
		
		/**
		 * Averages the numeric values in this sequence, skipping any cell
		 * value which is not a {@link java.lang.Number}.
		 * 
		 * @return the average of the cell values
		 */
		public Number average() {
			return average(Function.identity());
		}
	}

	/**
	 * A horizontal sequence of cells: one cell for each column of the table.
	 * 
	 * @author Stephen G. Ware
	 */
	public static final class Row extends Sequence {
		
		private Row(Table table, int number, Object label, Cell[] cells) {
			super(table, number, label, cells);
		}
		
		/**
		 * Renders this row as a table that keeps every column but no row
		 * other than this one.
		 */
		@Override
		public String toString() {
			Table onlyThisRow = table.transform(row -> this == row, column -> true);
			return onlyThisRow.toString();
		}
	}
	
	/**
	 * A vertical sequence of cells: one cell for each row of the table.
	 * 
	 * @author Stephen G. Ware
	 */
	public static final class Column extends Sequence {
		
		private Column(Table table, int number, Object label, Cell[] cells) {
			super(table, number, label, cells);
		}
		
		/**
		 * Renders this column as a table that keeps every row but no column
		 * other than this one.
		 */
		@Override
		public String toString() {
			Table onlyThisColumn = table.transform(row -> true, column -> this == column);
			return onlyThisColumn.toString();
		}
	}
	
	/**
	 * An individual unit of data in the table, found where one row meets one
	 * column.
	 * 
	 * @author Stephen G. Ware
	 */
	public static final class Cell {
		
		/** The table this cell belongs to */
		public final Table table;
		
		/** The row this cell belongs to */
		public final Row row;
		
		/** The column this cell belongs to */
		public final Column column;
		
		/** The current value of this cell ({@code null} until one is assigned) */
		public Object value;
		
		private Cell(Table table, Row row, Column column) {
			this.table = table;
			this.row = row;
			this.column = column;
		}
		
		/**
		 * Returns the row-major position of this cell in its table, i.e. the
		 * row number times the number of cells per row, plus the column
		 * number. This is distinct for every cell of one table. Note that
		 * {@code equals} is not overridden, so cells are still compared by
		 * identity.
		 */
		@Override
		public int hashCode() {
			// The multiplier must be the width of a row (the number of
			// columns). Using the height of a column here would make
			// different cells share a hash code whenever the table has more
			// columns than rows.
			return (row.number * row.cells.size()) + column.number;
		}
		
		/**
		 * Renders this cell as a table that keeps only this cell's row and
		 * column.
		 */
		@Override
		public String toString() {
			return table.transform(r -> r == row, c -> c == column).toString();
		}
	}
	
	/** The table's rows, in order of row number */
	public final ImmutableArray<Row> rows;
	/** Maps each row label to its row; filled in by the constructor and used by {@link #getRow(Object)} */
	private final HashMap<Object, Row> rowsByLabel = new HashMap<>();
	
	/** The table's columns, in order of column number */
	public final ImmutableArray<Column> columns;
	/** Maps each column label to its column; filled in by the constructor and used by {@link #getColumn(Object)} */
	private final HashMap<Object, Column> columnsByLabel = new HashMap<>();
	
	/** All of the table's cells in row-major order (every cell of row 0, then every cell of row 1, and so on) */
	public final ImmutableArray<Cell> cells;
	
	/**
	 * Constructs a new table whose rows and columns carry the given labels
	 * and whose cells all start out with a {@code null} value.
	 * 
	 * @param rowLabels the labels of the rows, in row order
	 * @param columnLabels the labels of the columns, in column order
	 * @throws IllegalArgumentException if there are no row labels or no
	 * column labels, if any row labels are duplicates of one another, or if
	 * any column labels are duplicates of one another
	 */
	public Table(Object[] rowLabels, Object[] columnLabels) {
		final int rowCount = rowLabels.length;
		if(rowCount == 0)
			throw new IllegalArgumentException("Table must have at least 1 row.");
		final int columnCount = columnLabels.length;
		if(columnCount == 0)
			throw new IllegalArgumentException("Table must have at least 1 column.");
		Row[] rowArray = new Row[rowCount];
		Column[] columnArray = new Column[columnCount];
		Cell[] allCells = new Cell[rowCount * columnCount];
		Cell[][] byRow = new Cell[rowCount][columnCount];
		Cell[][] byColumn = new Cell[columnCount][rowCount];
		// Rows and columns have to exist before the cells, because each cell
		// is constructed with its row and column. They are therefore handed
		// arrays that are still empty and only filled in further down.
		// NOTE(review): this presumably relies on ImmutableArray wrapping the
		// given array rather than copying it -- confirm against ImmutableArray.
		for(int r=0; r<rowCount; r++) {
			Row row = new Row(this, r, rowLabels[r], byRow[r]);
			rowArray[r] = row;
			if(!rowsByLabel.containsKey(row.label))
				rowsByLabel.put(row.label, row);
			else
				throw new IllegalArgumentException("Duplicate row label \"" + row.label + "\".");
		}
		for(int c=0; c<columnCount; c++) {
			Column column = new Column(this, c, columnLabels[c], byColumn[c]);
			columnArray[c] = column;
			if(!columnsByLabel.containsKey(column.label))
				columnsByLabel.put(column.label, column);
			else
				throw new IllegalArgumentException("Duplicate column label \"" + column.label + "\".");
		}
		// Create every cell once and file it under its row, under its column,
		// and in the flat row-major list of all cells.
		int next = 0;
		for(Row row : rowArray) {
			for(Column column : columnArray) {
				Cell cell = new Cell(this, row, column);
				byRow[row.number][column.number] = cell;
				byColumn[column.number][row.number] = cell;
				allCells[next++] = cell;
			}
		}
		this.rows = new ImmutableArray<>(rowArray);
		this.columns = new ImmutableArray<>(columnArray);
		this.cells = new ImmutableArray<>(allCells);
	}
	
	/**
	 * Constructs a new, empty table with rows and columns that have the given
	 * labels. The labels are converted to arrays with
	 * {@code Utilities.toArray} and handed to
	 * {@link #Table(Object[], Object[])}.
	 * 
	 * @param rowLabels the labels of the rows
	 * @param columnLabels the labels of the columns
	 * @throws IllegalArgumentException if there are no row labels or no
	 * column labels, if any row labels are duplicates of one another, or if
	 * any column labels are duplicates of one another
	 */
	public Table(Iterable<?> rowLabels, Iterable<?> columnLabels) {
		this(Utilities.toArray(rowLabels, Object.class), Utilities.toArray(columnLabels, Object.class));
	}
	
	/**
	 * Renders this table as plain text with one line per row, preceded by a
	 * header line of column labels. Every run of whitespace inside a label
	 * or value is collapsed to a single space, and each column is padded
	 * with spaces to the width of its longest entry. Row labels and column
	 * labels are padded on the right; cell values are padded on the left.
	 */
	@Override
	public String toString() {
		final int height = rows.size() + 1;
		final int width = columns.size() + 1;
		// The text grid has one extra row and column (index 0) for labels.
		String[][] grid = new String[height][width];
		grid[0][0] = "";
		for(Row row : rows)
			grid[row.number + 1][0] = toString(row.label);
		for(Column column : columns)
			grid[0][column.number + 1] = toString(column.label);
		for(Cell cell : cells)
			grid[cell.row.number + 1][cell.column.number + 1] = toString(cell.value);
		// Normalize whitespace and measure the widest entry of each column in
		// one pass over the grid.
		int[] widths = new int[width];
		for(String[] line : grid) {
			for(int c=0; c<width; c++) {
				line[c] = line[c].replaceAll("\\s+"," ");
				widths[c] = Math.max(widths[c], line[c].length());
			}
		}
		StringWriter text = new StringWriter();
		// Header line: blank space above the row labels, then column labels.
		text.append(string(' ', widths[0]));
		for(int c=1; c<width; c++)
			text.append(' ').append(rightPad(grid[0][c], widths[c]));
		// One line per row: the row label followed by the cell values.
		for(int r=1; r<height; r++) {
			text.append('\n').append(rightPad(grid[r][0], widths[0]));
			for(int c=1; c<width; c++)
				text.append(' ').append(leftPad(grid[r][c], widths[c]));
		}
		return text.toString();
	}
	
	/**
	 * Converts a label or cell value to text, treating {@code null} as the
	 * empty string.
	 */
	private static final String toString(Object object) {
		return object == null ? "" : object.toString();
	}
	
	/**
	 * Builds a string made of the given character repeated the given number
	 * of times.
	 */
	private static final String string(char c, int times) {
		char[] buffer = new char[times];
		int remaining = times;
		while(remaining > 0) {
			remaining--;
			buffer[remaining] = c;
		}
		return new String(buffer);
	}
	
	/**
	 * Appends spaces to the end of a string until it has the given length.
	 */
	private static final String rightPad(String string, int length) {
		String padding = string(' ', length - string.length());
		return string + padding;
	}
	
	/**
	 * Puts spaces in front of a string until it has the given length.
	 */
	private static final String leftPad(String string, int length) {
		String padding = string(' ', length - string.length());
		return padding + string;
	}
	
	/**
	 * Returns a new table with the same labels and the same cell values as
	 * this one. The label and value objects themselves are shared with this
	 * table, not copied.
	 */
	@Override
	public Table clone() {
		Function<Object, Object> unchanged = value -> value;
		return transform(unchanged);
	}
	
	/**
	 * Looks up a row by its label.
	 * 
	 * @param label the label object of the desired row
	 * @return the row which has this label
	 * @throws IllegalArgumentException if no row has this label
	 */
	public Row getRow(Object label) {
		Row match = rowsByLabel.get(label);
		if(match != null)
			return match;
		throw new IllegalArgumentException("There is no row with the label \"" + label + "\".");
	}
	
	/**
	 * Looks up a column by its label.
	 * 
	 * @param label the label object of the desired column
	 * @return the column which has this label
	 * @throws IllegalArgumentException if no column has this label
	 */
	public Column getColumn(Object label) {
		Column match = columnsByLabel.get(label);
		if(match != null)
			return match;
		throw new IllegalArgumentException("There is no column with the label \"" + label + "\".");
	}
	
	/**
	 * Returns the cell at the intersection of the row with the given label
	 * and the column with the given label.
	 * 
	 * @param rowLabel the row label object
	 * @param columnLabel the column label object
	 * @return the cell in the given row and column
	 * @throws IllegalArgumentException if no such row or column exists
	 */
	public Cell getCell(Object rowLabel, Object columnLabel) {
		Row row = getRow(rowLabel);
		Column column = getColumn(columnLabel);
		// Scan the row for the one cell that lies in the requested column.
		for(Cell candidate : row)
			if(candidate.column == column)
				return candidate;
		// Only reached if the row holds no cell for that column.
		return null;
	}
	
	/**
	 * Returns a new table with an additional row, placed after the existing
	 * rows, which has the given label. This table is not modified; its cell
	 * values are copied into the new table and the cells of the new row are
	 * left empty.
	 * 
	 * @param label the label for the new row
	 * @return the new table
	 * @throws IllegalArgumentException if this table already has a row with
	 * the given label
	 */
	public Table addRow(Object label) {
		return addRowOrColumn((rowLabels, columnLabels) -> rowLabels.add(label));
	}
	
	/**
	 * Returns a new table with an additional column, placed after the
	 * existing columns, which has the given label. This table is not
	 * modified; its cell values are copied into the new table and the cells
	 * of the new column are left empty.
	 * 
	 * @param label the label for the new column
	 * @return the new table
	 * @throws IllegalArgumentException if this table already has a column
	 * with the given label
	 */
	public Table addColumn(Object label) {
		return addRowOrColumn((rowLabels, columnLabels) -> columnLabels.add(label));
	}
	
	/**
	 * Builds a larger copy of this table. The row labels and column labels
	 * of this table are collected into two lists, the consumer is given the
	 * chance to modify those lists (first argument: row labels, second
	 * argument: column labels), and then a new table is created from the
	 * lists and filled with this table's values.
	 */
	private final Table addRowOrColumn(BiConsumer<ArrayList<Object>, ArrayList<Object>> consumer) {
		ArrayList<Object> rowLabels = new ArrayList<>();
		for(Row row : rows)
			rowLabels.add(row.label);
		ArrayList<Object> columnLabels = new ArrayList<>();
		for(Column column : columns)
			columnLabels.add(column.label);
		consumer.accept(rowLabels, columnLabels);
		Table expanded = new Table(rowLabels, columnLabels);
		copy(this, expanded);
		return expanded;
	}
	
	/**
	 * Copies cell values from one table to another, matching cells by their
	 * row and column labels. Cells of the destination whose row label or
	 * column label does not exist in the source are left untouched.
	 */
	private static final void copy(Table from, Table to) {
		for(Cell target : to.cells) {
			Object rowLabel = target.row.label;
			Object columnLabel = target.column.label;
			boolean inSource = from.rowsByLabel.containsKey(rowLabel) && from.columnsByLabel.containsKey(columnLabel);
			if(inSource)
				target.value = from.getCell(rowLabel, columnLabel).value;
		}
	}
	
	/** The label of the row or column which is added to hold totals */
	private static final String TOTAL_LABEL = "Total";
	
	/**
	 * Returns a new table with an additional row, labeled "Total", whose
	 * cells contain the totals of each column of this table. See
	 * {@link Sequence#sum(Function)}.
	 * 
	 * @param function a function for converting the values of the cells in
	 * each column to {@link java.lang.Number}s that will be applied to each
	 * cell value before adding it to the sum
	 * @return the new table
	 * @throws IllegalArgumentException if this table already has a row
	 * labeled "Total"
	 */
	public Table addTotalRow(Function<Object, ?> function) {
		Table result = addRow(TOTAL_LABEL);
		// Totals are computed from this table's columns, so the new (still
		// empty) row does not take part in the sums.
		for(Column column : columns) {
			Cell total = result.getCell(TOTAL_LABEL, column.label);
			total.value = column.sum(function);
		}
		return result;
	}
	
	/**
	 * Returns a new table with an additional row, labeled "Total", whose
	 * cells contain the totals of each column of this table. Cell values are
	 * used as they are, without conversion. See
	 * {@link Sequence#sum(Function)}.
	 * 
	 * @return the new table
	 */
	public Table addTotalRow() {
		Function<Object, Object> unchanged = Function.identity();
		return addTotalRow(unchanged);
	}
	
	/**
	 * Returns a new table with an additional column, labeled "Total", whose
	 * cells contain the totals of each row of this table. See
	 * {@link Sequence#sum(Function)}.
	 * 
	 * @param function a function for converting the values of the cells in
	 * each row to {@link java.lang.Number}s that will be applied to each cell
	 * value before adding it to the sum
	 * @return the new table
	 * @throws IllegalArgumentException if this table already has a column
	 * labeled "Total"
	 */
	public Table addTotalColumn(Function<Object, ?> function) {
		Table result = addColumn(TOTAL_LABEL);
		// Totals are computed from this table's rows, so the new (still
		// empty) column does not take part in the sums.
		for(Row row : rows) {
			Cell total = result.getCell(row.label, TOTAL_LABEL);
			total.value = row.sum(function);
		}
		return result;
	}
	
	/**
	 * Returns a new table with an additional column, labeled "Total", whose
	 * cells contain the totals of each row of this table. Cell values are
	 * used as they are, without conversion. See
	 * {@link Sequence#sum(Function)}.
	 * 
	 * @return the new table
	 */
	public Table addTotalColumn() {
		Function<Object, Object> unchanged = Function.identity();
		return addTotalColumn(unchanged);
	}
	
	/**
	 * Adds up the values of the given cells. The function is applied to
	 * every cell value first; results which are not
	 * {@link java.lang.Number}s are skipped. The sum starts at 0 and is
	 * accumulated with {@code Utilities.add}.
	 */
	private static final Number sum(ImmutableArray<Cell> cells, Function<Object, ?> function) {
		Number total = 0;
		for(Cell cell : cells) {
			Object converted = function.apply(cell.value);
			if(!(converted instanceof Number))
				continue;
			total = Utilities.add(total, (Number) converted);
		}
		return total;
	}
	
	/** The label of the row or column which is added to hold averages */
	private static final String AVERAGE_LABEL = "Average";
	
	/**
	 * Returns a new table with an additional row, labeled "Average", whose
	 * cells contain the averages of each column of this table. See
	 * {@link Sequence#average(Function)}.
	 * 
	 * @param function a function for converting the values of the cells in
	 * each column to {@link java.lang.Number}s that will be applied to
	 * each cell value before including it in the average
	 * @return the new table
	 * @throws IllegalArgumentException if this table already has a row
	 * labeled "Average"
	 */
	public Table addAverageRow(Function<Object, ?> function) {
		Table result = addRow(AVERAGE_LABEL);
		// Averages are computed from this table's columns, so the new (still
		// empty) row does not take part in them.
		for(Column column : columns) {
			Cell average = result.getCell(AVERAGE_LABEL, column.label);
			average.value = column.average(function);
		}
		return result;
	}
	
	/**
	 * Returns a new table with an additional row, labeled "Average", whose
	 * cells contain the averages of each column of this table. Cell values
	 * are used as they are, without conversion. See
	 * {@link Sequence#average(Function)}.
	 * 
	 * @return the new table
	 */
	public Table addAverageRow() {
		Function<Object, Object> unchanged = Function.identity();
		return addAverageRow(unchanged);
	}
	
	/**
	 * Returns a new table with an additional column, labeled "Average",
	 * whose cells contain the averages of each row of this table. See
	 * {@link Sequence#average(Function)}.
	 * 
	 * @param function a function for converting the values of the cells in
	 * each row to {@link java.lang.Number}s that will be applied to each cell
	 * value before including it in the average
	 * @return the new table
	 * @throws IllegalArgumentException if this table already has a column
	 * labeled "Average"
	 */
	public Table addAverageColumn(Function<Object, ?> function) {
		Table result = addColumn(AVERAGE_LABEL);
		// Averages are computed from this table's rows, so the new (still
		// empty) column does not take part in them.
		for(Row row : rows) {
			Cell average = result.getCell(row.label, AVERAGE_LABEL);
			average.value = row.average(function);
		}
		return result;
	}
	
	/**
	 * Returns a new table with an additional column, labeled "Average",
	 * whose cells contain the averages of each row of this table. Cell
	 * values are used as they are, without conversion. See
	 * {@link Sequence#average(Function)}.
	 * 
	 * @return the new table
	 */
	public Table addAverageColumn() {
		Function<Object, Object> unchanged = Function.identity();
		return addAverageColumn(unchanged);
	}
	
	/**
	 * Averages the values of the given cells. The function is applied to
	 * every cell value first; results which are not
	 * {@link java.lang.Number}s are skipped and are not counted. The
	 * accepted values are added with {@code Utilities.add} and the total is
	 * divided by their count with {@code Utilities.divide}. If no value is
	 * accepted, the result is 0.
	 */
	private static final Number average(ImmutableArray<Cell> cells, Function<Object, ?> function) {
		Number total = 0;
		int counted = 0;
		for(Cell cell : cells) {
			Object converted = function.apply(cell.value);
			if(!(converted instanceof Number))
				continue;
			total = Utilities.add(total, converted);
			counted++;
		}
		// Avoid dividing by zero when nothing numeric was found.
		if(counted != 0)
			return Utilities.divide(total, counted);
		return 0;
	}
	
	/**
	 * Returns a new table which only keeps the rows and columns of this
	 * table that are accepted by two predicates and whose contents are the
	 * results of applying the given function to the contents of this table.
	 * Note that the function is applied to the row labels and column labels
	 * of the kept rows and columns as well as to the cell values.
	 * 
	 * @param rows a predicate describing which rows to copy
	 * @param columns a predicate describing which columns to copy
	 * @param transform a function to transform the labels and the values of
	 * the cells
	 * @return the new table
	 * @throws IllegalArgumentException if no row or no column is accepted, or
	 * if two kept rows (or two kept columns) end up with the same transformed
	 * label
	 */
	public Table transform(Predicate<? super Row> rows, Predicate<? super Column> columns, Function<Object, ?> transform) {
		// For each kept row: its new label, plus a way back from the new
		// label to the label it had in this table.
		ArrayList<Object> keptRowLabels = new ArrayList<>();
		HashMap<Object, Object> originalRowLabels = new HashMap<>();
		for(Row row : this.rows) {
			if(!rows.test(row))
				continue;
			Object renamed = transform.apply(row.label);
			keptRowLabels.add(renamed);
			originalRowLabels.put(renamed, row.label);
		}
		// The same bookkeeping for the kept columns.
		ArrayList<Object> keptColumnLabels = new ArrayList<>();
		HashMap<Object, Object> originalColumnLabels = new HashMap<>();
		for(Column column : this.columns) {
			if(!columns.test(column))
				continue;
			Object renamed = transform.apply(column.label);
			keptColumnLabels.add(renamed);
			originalColumnLabels.put(renamed, column.label);
		}
		Table result = new Table(keptRowLabels, keptColumnLabels);
		// Fill each new cell from the cell of this table it corresponds to.
		for(Cell cell : result.cells) {
			Object sourceRow = originalRowLabels.get(cell.row.label);
			Object sourceColumn = originalColumnLabels.get(cell.column.label);
			cell.value = transform.apply(getCell(sourceRow, sourceColumn).value);
		}
		return result;
	}
	
	/**
	 * Returns a new table which only keeps the rows and columns of this
	 * table that are accepted by two predicates. Labels and cell values are
	 * carried over unchanged.
	 * 
	 * @param rows a predicate describing which rows to copy
	 * @param columns a predicate describing which columns to copy
	 * @return the new table
	 * @throws IllegalArgumentException if no row or no column is accepted
	 */
	public Table transform(Predicate<? super Row> rows, Predicate<? super Column> columns) {
		Function<Object, Object> unchanged = value -> value;
		return transform(rows, columns, unchanged);
	}
	
	/**
	 * Returns a new table with all the rows and columns of this table whose
	 * contents are the results of applying the given function to the
	 * contents of this table. See
	 * {@link #transform(Predicate, Predicate, Function)}.
	 * 
	 * @param transform a function to transform the values of the cells
	 * @return the new table
	 */
	public Table transform(Function<Object, ?> transform) {
		Predicate<Row> everyRow = row -> true;
		Predicate<Column> everyColumn = column -> true;
		return transform(everyRow, everyColumn, transform);
	}
	
	/**
	 * Returns a new table whose rows and columns have been sorted according to
	 * the given comparators. This table is not modified. The new table has
	 * the same labels and the same cell values, only in a different order.
	 * 
	 * @param rowComparator defines how the rows should be sorted
	 * @param columnComparator defines how the columns should be sorted
	 * @return the new table
	 */
	public Table sort(Comparator<? super Row> rowComparator, Comparator<? super Column> columnComparator) {
		ArrayList<Row> sortedRows = new ArrayList<>();
		for(Row row : this.rows)
			sortedRows.add(row);
		Collections.sort(sortedRows, rowComparator);
		ArrayList<Column> sortedColumns = new ArrayList<>();
		for(Column column : this.columns)
			sortedColumns.add(column);
		Collections.sort(sortedColumns, columnComparator);
		// Hand the constructor real lists of labels. An Iterable made from
		// stream::iterator can only be traversed a single time, which would
		// break as soon as the receiver iterates over it more than once.
		ArrayList<Object> rowLabels = new ArrayList<>(sortedRows.size());
		for(Row row : sortedRows)
			rowLabels.add(row.label);
		ArrayList<Object> columnLabels = new ArrayList<>(sortedColumns.size());
		for(Column column : sortedColumns)
			columnLabels.add(column.label);
		Table table = new Table(rowLabels, columnLabels);
		// Cells are matched by label, so the values follow their labels to
		// the new positions.
		copy(this, table);
		return table;
	}
	
	/**
	 * Returns a new table whose rows have been sorted according to the given
	 * comparator. The columns are passed to {@link #sort(Comparator, Comparator)}
	 * with a comparator that treats all columns as equal.
	 * 
	 * @param comparator defines how the rows should be sorted
	 * @return the new table
	 */
	public Table sortByRow(Comparator<? super Row> comparator) {
		Comparator<Column> allEqual = (first, second) -> 0;
		return sort(comparator, allEqual);
	}
	
	/**
	 * Returns a new table whose columns have been sorted according to the
	 * given comparator. The rows are passed to {@link #sort(Comparator, Comparator)}
	 * with a comparator that treats all rows as equal.
	 * 
	 * @param comparator defines how the columns should be sorted
	 * @return the new table
	 */
	public Table sortByColumn(Comparator<? super Column> comparator) {
		Comparator<Row> allEqual = (first, second) -> 0;
		return sort(allEqual, comparator);
	}
	
	/**
	 * Returns an HTML representation of this table: a {@code <table>}
	 * element whose first row holds the column labels and whose remaining
	 * rows each start with the row label (both as {@code <th>} elements)
	 * followed by the cell values (as {@code <td>} elements). Labels and
	 * values are written as text, so characters that are special in HTML are
	 * escaped rather than interpreted as markup. A {@code null} label or
	 * value produces an empty element.
	 * 
	 * @return an HTML string
	 */
	public String toHTML() {
		StringBuilder html = new StringBuilder();
		html.append("<table>\n\t<tr>\n\t\t<th></th>");
		for(Column column : columns)
			html.append("\n\t\t<th>").append(escapeHTML(toString(column.label))).append("</th>");
		html.append("\n\t</tr>");
		for(Row row : rows) {
			html.append("\n\t<tr>\n\t\t<th>").append(escapeHTML(toString(row.label))).append("</th>");
			for(Cell cell : row)
				html.append("\n\t\t<td>").append(escapeHTML(toString(cell.value))).append("</td>");
			html.append("\n\t</tr>");
		}
		html.append("\n</table>");
		return html.toString();
	}
	
	/**
	 * Replaces the characters which have a special meaning in HTML with the
	 * corresponding character references so the text is displayed literally.
	 */
	private static final String escapeHTML(String text) {
		// An object whose toString() returns null used to be concatenated
		// as the text "null"; keep that instead of failing.
		if(text == null)
			return "null";
		StringBuilder escaped = new StringBuilder(text.length());
		for(int i=0; i<text.length(); i++) {
			char c = text.charAt(i);
			switch(c) {
			case '&':
				escaped.append("&amp;");
				break;
			case '<':
				escaped.append("&lt;");
				break;
			case '>':
				escaped.append("&gt;");
				break;
			case '"':
				escaped.append("&quot;");
				break;
			case '\'':
				escaped.append("&#39;");
				break;
			default:
				escaped.append(c);
			}
		}
		return escaped.toString();
	}
}
