#! /usr/bin/env python3
from pddl import PropositionalAction, NegatedAtom, Atom, Truth, Disjunction, TypedObject


class Rin17Candidate:
    """Lifted invariant candidate: a disjunction of literals over typed variables.

    The candidate may be restricted by pairwise inequality constraints between
    variables and/or objects. Grounding it (see `instantiate`) replaces the
    variables with objects one at a time. Groundings that violate an inequality
    constraint, or that turn the disjunction into a tautology, are discarded.

    The hash is computed once in `__init__`, so instances must not be mutated
    after construction.
    """

    def __init__(self, variables, disjunction: Disjunction, inequality_constraint=None):
        """
        :param variables: iterable of TypedObjects that still have to be grounded
        :param disjunction: Disjunction whose literal args are variable names or ground objects
        :param inequality_constraint: iterable of pairs whose elements are TypedObject
            variables or ground objects; defaults to "no constraints"
        """
        # None instead of a [] default: Python evaluates default arguments once, so a
        # list default would be shared by every candidate constructed without one.
        if inequality_constraint is None:
            inequality_constraint = []

        # Stored as tuples so the candidate is hashable and the members cannot be appended to.
        self._variables = tuple(variables)
        self.disjunction = disjunction
        self.inequality_constraint = tuple(inequality_constraint)

        self.hash = hash((self.__class__, self._variables, self.disjunction, self.inequality_constraint))
        self.name = self._build_name()

    def _build_name(self):
        """Return the human-readable name, e.g. "((?a!=?b)) -- [p(?a) , ¬q(?b)]"."""
        constraints = "".join("(%s!=%s)" % (a, b) for (a, b) in self.inequality_constraint)
        literals = " , ".join(
            "%s%s(%s)" % ("¬" if part.negated else "", part.predicate, ", ".join(map(str, part.args)))
            for part in self.disjunction.parts
        )
        return "(" + constraints + ") -- [" + literals + "]"

    def __eq__(self, other):
        """Structural equality; the cached hash is only used as a cheap pre-filter.

        Equal hashes alone do not imply equal candidates (hash collisions), so the
        members are compared as well.
        """
        if isinstance(other, Rin17Candidate):
            return (self.hash == other.hash
                    and self._variables == other._variables
                    and self.disjunction == other.disjunction
                    and self.inequality_constraint == other.inequality_constraint)
        return NotImplemented

    def __ne__(self, other):
        """Negation of __eq__ (kept explicitly; Python 3 would derive it)."""
        x = self.__eq__(other)
        if x is not NotImplemented:
            return not x
        return NotImplemented

    def __hash__(self):
        """Return the hash cached at construction time."""
        return self.hash

    def __repr__(self):
        return '<%s>' % self.name

    def get_variables(self):
        """Return the tuple of not-yet-grounded TypedObject variables."""
        return self._variables

    def get_variable_from_name(self, variable_name):
        """Return the variable called `variable_name`, or None if there is none."""
        return next((variable for variable in self._variables if variable.name == variable_name), None)

    def get_type_name(self, variable_name):
        """Return the type name of the variable called `variable_name`, or None if there is none."""
        variable = self.get_variable_from_name(variable_name)
        return None if variable is None else variable.type_name

    def get_type_names(self, variable_list):
        """Return the type names (or None for unknown names) of the given variable names, in order."""
        return [self.get_type_name(variable) for variable in variable_list]

    def get_index(self, target):
        """Return the position of the variable named `target`, or None if there is none."""
        return next((i for i, variable in enumerate(self._variables) if variable.name == target), None)

    def is_in_inequaltiy_constraint(self, other):
        """Return True if the pair `other` (in either order) is one of the inequality constraints.

        Elements may be TypedObjects or plain objects (e.g. str). They are only
        compared with == when their types match, so that TypedObject.__eq__ is
        never called with a non-TypedObject argument.
        """
        def same(x, y):
            return type(x) == type(y) and x == y

        return any(
            (same(ieqc[0], other[0]) and same(ieqc[1], other[1]))
            or (same(ieqc[0], other[1]) and same(ieqc[1], other[0]))
            for ieqc in self.inequality_constraint
        )

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    is_in_inequality_constraint = is_in_inequaltiy_constraint

    def instantiate(self, type_name_to_ground_objects):
        """
        Generate all groundings possible (if any, 0 is unlikely) under the given mapping.

        :param type_name_to_ground_objects: mapping from the type name of every variable in this candidate
            to the objects of that type
        :return: A set containing all grounded Disjunctions that result from this candidate and the mapping given
        """
        grounded_candidates = set()
        self._partially_instantiate(grounded_candidates, type_name_to_ground_objects)

        grounded_invariants = set()
        for gc in grounded_candidates:
            assert (len(gc.get_variables()) == 0), "After grounding (instantiating), " \
                                                   "there should not be any variables in a candidate"
            grounded_invariants.add(gc.disjunction)
        return grounded_invariants

    def _ground_inequality_constraint(self, variable, ground_object):
        """Substitute `ground_object` for `variable` in every inequality constraint.

        :return: the list of substituted constraint tuples, or None if some
            constraint becomes "x != x" (i.e. this grounding is not allowed)
        """
        grounded = []
        for iec in self.inequality_constraint:
            assert (len(iec) == 2), "All inequality constraints must always have length 2."
            parts = [ground_object if type(element) is TypedObject and element.name == variable.name else element
                     for element in iec]
            if type(parts[0]) == type(parts[1]) and parts[0] == parts[1]:
                return None
            grounded.append(tuple(parts))
        return grounded

    def _partially_instantiate(self, grounded_invariants, type_name_to_ground_objects):
        """
        Recursively replaces variables with objects until a candidate is produced, where there are no variables anymore.
        To this end, in each recursion step, new candidates with one variable less are generated and this method called
        on these new candidates.

        :param grounded_invariants: set of Rin17Candidates (can be empty, will be filled if any grounding possible)
        :param type_name_to_ground_objects: mapping containing the types of all variables in the candidate self
        :return: None (fills "grounded_invariants" with fully grounded candidates)
        """
        # Imported here because importing at the top of the module would be circular.
        from rin_invariant_util import build_literal

        # Recursion stop: no variables left in either the disjunction or the inequality constraints.
        if not self._variables:
            grounded_invariants.add(self)
            return

        # Any variable will do; it is dropped from the variable list before recursing.
        current_variable = self._variables[0]
        remaining_variables = self._variables[1:]

        for ground_object in type_name_to_ground_objects[current_variable.type_name]:
            # Inequalities first: this is the step that can reject the current object.
            new_inequality_constraint = self._ground_inequality_constraint(current_variable, ground_object)
            if new_inequality_constraint is None:
                continue  # violates an inequality, but other objects might work

            new_parts = [
                build_literal(literal.predicate,
                              [ground_object if arg == current_variable.name else arg for arg in literal.args],
                              literal.negated)
                for literal in self.disjunction.parts
            ]
            # If both a and not-a occur, the disjunction is a tautology and is not kept.
            # Only THIS grounding is discarded: other objects for the same variable may
            # still produce a non-tautological disjunction, so continue rather than return.
            if any(part.negate() in new_parts for part in new_parts):
                continue

            more_strongly_grounded_candidate = Rin17Candidate(remaining_variables, Disjunction(new_parts),
                                                              new_inequality_constraint)
            more_strongly_grounded_candidate._partially_instantiate(grounded_invariants, type_name_to_ground_objects)


class Rin17PartialOperator:  # partially grounded
    """A lifted operator plus a (possibly partial) binding of its parameter names to objects.

    Parameters are bound one at a time via `partially_instantiate`. Once all of
    them are bound, `instantate` produces the corresponding PropositionalAction.
    """

    def __init__(self, operator, variable_to_object):
        """
        :param operator: the lifted STRIPS action to ground
        :param variable_to_object: dict from parameter name to bound object; stored by reference, not copied
        """
        self.operator = operator
        self.variable_to_object = variable_to_object
        self.name = "partial of " + operator.name

    def copy(self):
        """Return a new partial operator that shares the operator but owns a copy of the binding dict."""
        return Rin17PartialOperator(self.operator, self.variable_to_object.copy())

    def partially_instantiate(self, position, target):
        """Bind the operator parameter at index `position` to `target` (mutates this object)."""
        target_parameter_name = self.operator.parameters[position].name
        self.variable_to_object[target_parameter_name] = target

    def _ground_args(self, args):
        """Replace bound parameter names by their objects; other args (e.g. PDDL constants) are kept as-is."""
        return [self.variable_to_object.get(arg, arg) for arg in args]

    def instantate(self):
        """Ground the operator under the current binding.

        :return: PropositionalAction named "<op> <param> <obj> ..." with grounded preconditions and effects
        :raises KeyError: if some operator parameter is still unbound
        :raises AssertionError: if an effect is conditional (only STRIPS is supported)
        """
        name = self.operator.name
        for parameter in self.operator.parameters:
            # Indexing (not .get) on purpose: every parameter must be bound at this point.
            name += " " + parameter.name + " " + self.variable_to_object[parameter.name]

        precondition_pddl = self.operator.precondition
        if isinstance(precondition_pddl, (Atom, NegatedAtom)):
            # A single-literal precondition is not a junctor, so its `parts` does not
            # contain the literal itself (in Fast Downward's pddl module a literal's
            # `parts` is empty). Wrap it so the precondition is not silently dropped.
            preconditions_pddl = [precondition_pddl]
        else:
            preconditions_pddl = precondition_pddl.parts

        preconditions = []
        for literal in preconditions_pddl:
            literal_class = NegatedAtom if literal.negated else Atom
            preconditions.append(literal_class(literal.predicate, self._ground_args(literal.args)))

        effects = []
        for effect_pddl in self.operator.effects:
            assert type(effect_pddl.condition) == Truth, "Rintanens Algorithm is only implemented for STRIPS"
            atom = effect_pddl.literal
            atom_class = NegatedAtom if atom.negated else Atom
            # Effects are (condition, literal) pairs; the condition list is empty for STRIPS.
            effects.append(([], atom_class(atom.predicate, self._ground_args(atom.args))))

        return PropositionalAction(name, preconditions, effects, self.operator.cost)

    # Correctly spelled alias; the misspelled name is kept for existing callers.
    instantiate = instantate


class Rin17Assignement:
    """Odometer-style enumeration of every assignment of objects to a fixed list of typed parameters.

    Each parameter holds an index into the object list of its type
    (`type_name_to_objects[parameter.type_name]`). Position 0 changes fastest.
    Parameters are used as dict keys and must therefore be hashable.
    """

    def __init__(self, parameters, type_name_to_objects):
        """
        :param parameters: iterable of variables with a `type_name` attribute; its order fixes the enumeration order
        :param type_name_to_objects: dict from type name to an indexable sequence of objects of that type
        """
        # Materialise once: the parameters are traversed more than once below. A one-shot
        # iterator (e.g. a generator) would otherwise leave `assignment` empty.
        parameters = list(parameters)
        self.type_name_to_objects = type_name_to_objects
        # Random access in both directions (parameter <-> position).
        self.variable_to_index = {}
        self.index_to_variable = {}
        for index, parameter in enumerate(parameters):
            self.variable_to_index[parameter] = index
            self.index_to_variable[index] = parameter
        self.length = len(parameters)
        # Every parameter starts at the first object of its type.
        self.assignment = [0] * self.length

    def increase(self):
        """Advance to the next assignment deterministically, so every combination is visited.

        :return: True if a new assignment was reached; False once all combinations
            are exhausted (the assignment has then wrapped back to all zeros)
        """
        return self._recursivly_increase(0)

    def get(self, variable):
        """Return the object currently assigned to `variable`."""
        return self.type_name_to_objects[variable.type_name][self.assignment[self.variable_to_index[variable]]]

    def _recursivly_increase(self, position):
        """Increment the counter at `position`, carrying into higher positions on overflow.

        Despite the (kept) name, this is implemented as a loop.
        """
        for current in range(position, self.length):
            object_count = len(self.type_name_to_objects[self.index_to_variable[current].type_name])
            if self.assignment[current] < object_count - 1:
                self.assignment[current] += 1
                return True
            # Overflow: reset this digit and carry into the next position.
            self.assignment[current] = 0
        return False
