package recursion.schema;

import java.util.Arrays;

import recursion.Function;

/**
 * A function obtained by the primitive recursion schema from a base case {@code g} and a
 * recursive case {@code h}:
 *
 * <pre>
 *   f(0,     x1..xk) = g(x1..xk)
 *   f(n + 1, x1..xk) = h(f(n, x1..xk), n, x1..xk)
 * </pre>
 *
 * <p>If {@code g} has arity {@code k}, the resulting function has arity {@code k + 1} and
 * {@code h} must have arity {@code k + 2}.
 */
public class PrimitiveRecursion extends Function {
    /** Base case, applied to the trailing arguments when the recursion argument is 0. */
    private final Function g;
    /** Recursive case, applied to (previous value, step counter, trailing arguments). */
    private final Function h;

    /**
     * Builds {@code f = PR(g, h)}.
     *
     * @param g base case of arity {@code k}
     * @param h recursive case of arity {@code k + 2}
     * @throws IllegalArgumentException if {@code h.getArity() != g.getArity() + 2}
     */
    public PrimitiveRecursion(Function g, Function h) {
        super(g.getArity() + 1);
        int expectedHArity = g.getArity() + 2;
        if (h.getArity() != expectedHArity) {
            throw new IllegalArgumentException(
                "Tried to apply primitive recursion schema where the base case (g) has arity "
                    + g.getArity() + " and the recursive case (h) has arity " + h.getArity()
                    + " instead of " + expectedHArity);
        }
        this.g = g;
        this.h = h;
    }

    /**
     * Evaluates {@code f(n, x1..xk)} bottom-up.
     *
     * <p>Evaluation starts from {@code g(x1..xk)} and applies {@code h} once per step
     * {@code i = 0 .. n-1}, feeding each result into the next step. {@code g} and {@code h} are
     * invoked in the same order and with the same arguments as when unfolding the recursive
     * definition. Because a loop is used, large {@code n} cannot exhaust the call stack.
     *
     * <p>Intermediate values {@code f(i, ...)} for {@code i < n} are computed inline rather than
     * by re-entering {@code compute} on this function.
     *
     * @param args {@code [n, x1, ..., xk]}; its length is assumed to equal this function's arity
     *     (presumably checked by the caller — TODO confirm in {@code Function.compute})
     * @return {@code f(n, x1..xk)}
     * @throws IllegalArgumentException if {@code n} is negative; the schema is only defined for
     *     natural numbers, and the recursive form would never reach its base case
     */
    @Override
    protected int computeValue(int[] args) {
        int n = args[0];
        if (n < 0) {
            throw new IllegalArgumentException(
                "Primitive recursion is undefined for a negative recursion argument: " + n);
        }

        int[] gArgs = Arrays.copyOfRange(args, 1, args.length);
        int value = g.compute(gArgs);

        int tailLength = args.length - 1;
        for (int i = 0; i < n; ++i) {
            // h receives (f(i, x1..xk), i, x1..xk). Use a fresh array each step so h never sees
            // a buffer that is later overwritten, matching the original recursive form.
            int[] hArgs = new int[h.getArity()];
            hArgs[0] = value;
            hArgs[1] = i;
            System.arraycopy(args, 1, hArgs, 2, tailLength);
            value = h.compute(hArgs);
        }
        return value;
    }
}
