package recursion;

import java.io.BufferedReader;
import java.io.FileReader;
import java.io.IOException;
import java.util.HashMap;

import recursion.basicfunctions.NullFunction;
import recursion.basicfunctions.Projection;
import recursion.basicfunctions.SuccFunction;
import recursion.schema.MuRecursion;
import recursion.schema.PrimitiveRecursion;
import recursion.schema.Composition;

/**
 * Command-line interpreter for a small language of recursive functions.
 *
 * <p>The built-in functions are {@code null}, {@code succ} and the projections
 * {@code pi_i_k}. New functions can be defined with the {@code compose},
 * {@code primitive_recursion} and {@code mu_recursion} schemas. Each input line is one of:
 * <ul>
 *   <li>a comment (starts with {@code #}) or a blank line, which is ignored;</li>
 *   <li>an assignment, e.g. {@code f = compose(h, g1, g2)};</li>
 *   <li>a print statement, e.g. {@code print f(1, 2, 3)}.</li>
 * </ul>
 */
public class Recursion {
    /** User-defined functions, keyed by the name they were assigned to. */
    private static final HashMap<String, Function> namedFunctions = new HashMap<String, Function>();

    /**
     * Entry point.
     *
     * @param args exactly one argument: the filename of the file to interpret
     */
    public static void main(String[] args) {
        if (args.length != 1) {
            System.err.println("Usage: java recursion.Recursion FILENAME");
            System.exit(1);
        }
        interpretFile(args[0]);
    }

    /**
     * Resolves a function name to a function.
     *
     * <p>Built-in names ({@code null}, {@code succ}, {@code pi_i_k} with positive
     * {@code i} and {@code k}) always produce a fresh instance. Other names are
     * looked up among the user-defined functions.
     *
     * @param functionName the name to resolve
     * @param mustExists   if true, an unknown name is an error; otherwise {@code null} is returned
     * @return the resolved function, or {@code null} if it is unknown and {@code mustExists} is false
     * @throws RuntimeException if the name is unknown and {@code mustExists} is true
     */
    private static Function getNamedFunction(String functionName, boolean mustExists) {
        if (functionName.equals("null")) {
            return new NullFunction();
        } else if (functionName.equals("succ")) {
            return new SuccFunction();
        } else if (functionName.matches("pi_([1-9]\\d*)_([1-9]\\d*)")) {
            // "pi_i_k" splits into ["pi", i, k]
            String[] parameters = functionName.split("_");
            return new Projection(Integer.valueOf(parameters[1]), Integer.valueOf(parameters[2]));
        } else if (namedFunctions.containsKey(functionName)) {
            return namedFunctions.get(functionName);
        } else if (mustExists) {
            throw new RuntimeException("Could not find function called '" + functionName + "'");
        } else {
            return null;
        }
    }

    /**
     * Interprets a file one statement per line.
     *
     * <p>A line that fails is reported on stderr, together with its 1-based line
     * number, and skipped so that the rest of the file is still interpreted. If the
     * file cannot be read, an error is printed and the process exits with status 1.
     *
     * @param filename path of the file to interpret
     */
    private static void interpretFile(String filename) {
        try {
            BufferedReader br = new BufferedReader(new FileReader(filename));
            try {
                String line;
                int lineNumber = 0;
                while ((line = br.readLine()) != null) {
                    // Count every physical line, including failing ones, so that
                    // reported line numbers always match the file.
                    ++lineNumber;
                    try {
                        interpretLine(line);
                    } catch (Exception e) {
                        System.err.println("Could not parse line " + lineNumber + ": " + line);
                        System.err.println(describe(e));
                    }
                }
            } finally {
                br.close();
            }
        } catch (IOException e) {
            System.err.println("Could not read file '" + filename + "': " + describe(e));
            System.exit(1);
        }
    }

    /**
     * Interprets a single line: a comment or blank line, a print statement, or an assignment.
     *
     * @param line the raw line; surrounding whitespace is ignored
     * @throws RuntimeException if the line cannot be parsed or evaluated
     */
    private static void interpretLine(String line) {
        line = line.trim();
        if (line.isEmpty() || line.startsWith("#")) {
            return;
        }
        if (line.startsWith("print ")) {
            interpretPrintStatement(line.substring("print ".length()).trim());
        } else if (line.contains("=")) {
            // A limit of -1 keeps trailing empty strings. "f =" then yields two parts
            // (the empty definition is rejected later with a clear message) instead
            // of being misreported as having multiple '='.
            String[] parts = line.split("=", -1);
            if (parts.length != 2) {
                throw new RuntimeException("Multiple '=' in one line not supported: " + line);
            }
            interpretAssignmentStatement(parts[0].trim(), parts[1].trim());
        } else {
            throw new RuntimeException("Cannot parse line: " + line);
        }
    }

    /**
     * Defines a new named function from a schema application.
     *
     * <p>Supported forms are {@code compose(h, g_1, ..., g_i)},
     * {@code primitive_recursion(g, h)} and {@code mu_recursion(f)}.
     *
     * @param functionName       the (new) name to bind; must consist of word characters only
     * @param functionDefinition the schema application defining the function
     * @throws RuntimeException if the name is invalid or already defined, or if the
     *                          definition is malformed or references unknown functions
     */
    private static void interpretAssignmentStatement(String functionName, String functionDefinition) {
        // Only \w+ names can later be referenced by print statements or schema arguments.
        if (!functionName.matches("\\w+")) {
            throw new RuntimeException("Invalid function name: '" + functionName + "'");
        }
        if (getNamedFunction(functionName, false) != null) {
            throw new RuntimeException("Cannot redefine the function: '" + functionName + "'");
        }
        // "schema(a, b, c)" splits into ["schema", "a, b, c"]. Without a non-empty
        // argument list there is no second part, so reject the line before indexing.
        String[] parts = functionDefinition.split("[()]");
        if (parts.length < 2) {
            throw new RuntimeException("Cannot parse assignment statement: " + functionDefinition);
        }
        String schemaName = parts[0].trim();
        String[] schemaArgs = parts[1].split(",");
        Function f;
        if (schemaName.equals("compose")) {
            if (!functionDefinition.matches("(\\w+)\\(\\s*(\\w+)(\\s*,\\s*\\w+){1,}\\s*\\)")) {
                throw new RuntimeException("Cannot parse compose definition. Expected format 'compose(h, g_1, g_2, ...)' but got: " + functionDefinition);
            }
            Function h = getNamedFunction(schemaArgs[0].trim(), true);
            Function[] gs = new Function[schemaArgs.length - 1];
            for (int j = 1; j < schemaArgs.length; ++j) {
                gs[j - 1] = getNamedFunction(schemaArgs[j].trim(), true);
            }
            // The regex above guarantees at least one g, so gs[0] exists.
            int i = gs.length;
            int k = gs[0].getArity();
            /*
              TODO: Computing i and k here and passing them into
              "Composition" is a hack to transition from the old compose
              syntax quickly. It would be better to compute i and k
              inside Composition. Note also that many of the checks inside
              the Composition constructor are redundant now that i and k
              are determined automatically. We only need the check that
              all g functions have matching arity.
            */
            f = new Composition(k, i, h, gs);
        } else if (schemaName.equals("primitive_recursion")) {
            if (!functionDefinition.matches("(\\w+)\\(\\s*(\\w+)\\s*,\\s*(\\w+)\\s*\\)")) {
                throw new RuntimeException("Cannot parse primitive recursion definition. Expected format 'primitive_recursion(g, h)' but got: " + functionDefinition);
            }
            Function g = getNamedFunction(schemaArgs[0].trim(), true);
            Function h = getNamedFunction(schemaArgs[1].trim(), true);
            f = new PrimitiveRecursion(g, h);
        } else if (schemaName.equals("mu_recursion")) {
            if (!functionDefinition.matches("(\\w+)\\(\\s*(\\w+)\\s*\\)")) {
                throw new RuntimeException("Cannot parse mu recursion definition. Expected format 'mu_recursion(f)' but got: " + functionDefinition);
            }
            Function fInner = getNamedFunction(schemaArgs[0].trim(), true);
            f = new MuRecursion(fInner);
        } else {
            throw new RuntimeException("Cannot parse assignment statement: " + functionDefinition);
        }
        namedFunctions.put(functionName, f);
        System.out.println("Defined function '" + functionName + "' as a new function with arity " + f.getArity());
    }

    /**
     * Evaluates a function call such as {@code f(1, 2, 3)} and prints the result to stdout.
     *
     * <p>An exception thrown by the computation itself is printed to stdout rather
     * than propagated.
     *
     * @param functionCall the call text following the {@code print} keyword
     * @throws RuntimeException if the call is malformed, the function is unknown,
     *                          or an argument does not fit in an {@code int}
     */
    private static void interpretPrintStatement(String functionCall) {
        if (!functionCall.matches("(\\w+)\\(\\s*(\\d+)\\s*(\\s*,\\s*\\d+)*\\s*\\)")) {
            throw new RuntimeException("Cannot parse print statement. Expected format 'functionName(1,2,3,4)' but got " + functionCall);
        }
        // "f(1, 2)" splits into ["f", "1, 2"]
        String[] parts = functionCall.split("[()]");
        String functionName = parts[0];
        String functionArgs = parts[1];
        Function f = getNamedFunction(functionName, true);
        String[] stringArgs = functionArgs.split(",");
        int[] args = new int[stringArgs.length];
        for (int i = 0; i < stringArgs.length; i++) {
            args[i] = Integer.parseInt(stringArgs[i].trim());
        }
        try {
            int res = f.compute(args);
            System.out.println(functionCall + " = " + res);
        } catch (Exception e) {
            System.out.println(describe(e));
        }
    }

    /**
     * Returns a human-readable description of an exception for error output.
     * Falls back to {@code toString()} (the class name) when there is no message.
     *
     * @param e the exception to describe
     * @return the exception's message, or its string form if the message is null
     */
    private static String describe(Exception e) {
        String message = e.getMessage();
        return message != null ? message : e.toString();
    }
}
