from pddl import TypedObject, Disjunction, Falsity, Literal
from rin17_invariants import Rin17Assignement, Rin17Candidate
from rin_invariant_util import build_literal, get_from_joined_list


class ResolutionMonitor:
    """Hand out numbered names such as ``block0``, ``block1`` per type name.

    The counters are stored in one class-level dictionary, so all instances
    and the static accessor below operate on the same shared state.
    """

    # Maps a type name to the numeric suffix that will be handed out next.
    type_to_variable_counter_name = {}

    def __init__(self, type_names):
        """Set the counter of every name in ``type_names`` to zero.

        Type names that were registered earlier and are not listed here keep
        their current counter.
        """
        counters = self.type_to_variable_counter_name
        for name in type_names:
            counters[name] = 0

    @staticmethod
    def pop_next_object_name(type_name):
        """Return ``type_name`` followed by its counter and advance the counter.

        Raises ``KeyError`` when ``type_name`` has no counter yet.
        """
        counters = ResolutionMonitor.type_to_variable_counter_name
        suffix = counters[type_name]
        fresh_name = type_name + str(suffix)
        counters[type_name] = suffix + 1
        return fresh_name


def rename_variables_in_literal(atom, variable_to_new_name) -> Literal:
    """Apply the renaming ``variable_to_new_name`` to the arguments of ``atom``.

    :param atom: literal providing ``args``, ``predicate`` and ``negated``.
    :param variable_to_new_name: mapping from an argument name to its replacement.
    :return: ``atom`` itself when none of its arguments is a key of the mapping,
        otherwise a new literal created by ``build_literal`` with the same
        predicate and negation flag and the substituted arguments.
    """
    if not any(arg in variable_to_new_name for arg in atom.args):
        # Nothing to rename: hand back the very same object.
        return atom
    renamed_args = [variable_to_new_name.get(arg, arg) for arg in atom.args]
    return build_literal(atom.predicate, renamed_args, atom.negated)


def rename_variables(candidate: Rin17Candidate, variable_to_new_name) -> Rin17Candidate:
    """Build a new candidate in which variables are renamed or replaced by objects.

    A replacement name containing ``"?"`` is treated as a variable name, any
    other replacement as an object name.

    :param candidate: candidate providing ``get_variables()``,
        ``inequality_constraint`` and ``disjunction.parts``.
    :param variable_to_new_name: mapping from a variable name to its replacement.
    :return: a new ``Rin17Candidate``; the input candidate is not modified here.
    """

    def as_term(name, type_name):
        # Variables stay typed objects, object names are kept as plain values.
        return TypedObject(name, type_name) if "?" in name else name

    def is_renamed_variable(term):
        return type(term) == TypedObject and term.name in variable_to_new_name

    def substitute(term):
        return as_term(variable_to_new_name[term.name], term.type_name)

    # Variables: keep untouched ones, rename the mapped ones, and drop those
    # that are replaced by an object.
    kept_variables = []
    for variable in candidate.get_variables():
        if variable.name not in variable_to_new_name:
            kept_variables.append(variable)
            continue
        replacement = variable_to_new_name[variable.name]
        if "?" in replacement:
            kept_variables.append(TypedObject(replacement, variable.type_name))

    # Inequality constraints: substitute each side independently; a constraint
    # that is not affected at all is carried over as the same object.
    renamed_constraints = []
    for constraint in candidate.inequality_constraint:
        first, second = constraint[0], constraint[1]
        rename_first = is_renamed_variable(first)
        rename_second = is_renamed_variable(second)
        if rename_first or rename_second:
            renamed_constraints.append((substitute(first) if rename_first else first,
                                        substitute(second) if rename_second else second))
        else:
            renamed_constraints.append(constraint)

    # Disjunction: rename every literal.
    renamed_parts = [rename_variables_in_literal(part, variable_to_new_name)
                     for part in candidate.disjunction.parts]

    return Rin17Candidate(kept_variables, Disjunction(renamed_parts), renamed_constraints)


# Each inequality constraint is a 2-tuple whose entries are objects (str) or variables (TypedObject).
def test_assignment_against_inequality(assignment: Rin17Assignement, inequality_constraints):
    """Check whether ``assignment`` satisfies every inequality constraint.

    :param assignment: object whose ``get(variable)`` yields the value currently
        assigned to a ``TypedObject`` variable.
    :param inequality_constraints: iterable of 2-tuples; entries that are
        ``TypedObject`` instances are looked up in the assignment, all other
        entries are compared as they are.
    :return: ``False`` if both sides of some constraint are equal after
        grounding, ``True`` otherwise (also when there are no constraints).
    """

    def ground(term):
        return assignment.get(term) if type(term) == TypedObject else term

    # Ground all constraints first, compare afterwards.
    grounded_pairs = []
    for constraint in inequality_constraints:
        first, second = constraint[0], constraint[1]
        if type(first) == TypedObject or type(second) == TypedObject:
            grounded_pairs.append((ground(first), ground(second)))
        else:
            grounded_pairs.append(constraint)

    return not any(lhs == rhs for lhs, rhs in grounded_pairs)


def partially_ground_atom(atom, variable, propositional_object, type_name_to_objects):
    """Replace ``variable`` in ``atom`` by ``propositional_object``.

    Every occurrence of the variable is replaced, so a literal with a repeated
    variable such as ``p(?x, ?x)`` is grounded the same way
    ``rename_variables_in_literal`` grounds the literals of a candidate.

    :param atom: literal providing ``args``, ``predicate`` and ``negated``.
    :param variable: variable providing ``name`` and ``type_name``.
    :param propositional_object: object name to insert.
    :param type_name_to_objects: mapping from a type name to its object names.
    :return: ``(False, None)`` when the object is not listed for the variable's
        type, otherwise ``(True, literal)`` with the new literal created by
        ``build_literal``.
    :raises KeyError: if the variable's type name is not in ``type_name_to_objects``.
    :raises ValueError: if the variable does not occur in the atom's arguments.
    """
    if propositional_object not in type_name_to_objects[variable.type_name]:
        return False, None
    if variable.name not in atom.args:
        # Same exception type the former list.index() lookup raised.
        raise ValueError("%r is not an argument of the atom" % (variable.name,))
    grounded_args = [propositional_object if arg == variable.name else arg for arg in atom.args]
    return True, build_literal(atom.predicate, grounded_args, atom.negated)


def distinctify_variables(left: Rin17Candidate, left_atom: Literal, right: Rin17Candidate, right_atom: Literal):
    """Rename the variables of ``left`` so that they fit together with ``right``.

    Two things happen:

    * where both resolving literals carry a variable at the same position, the
      left variable takes over the name of the right variable;
    * every other left variable whose name is also used by ``right`` gets a
      fresh name, formed by appending the smallest unused number.

    :return: ``(left, left_atom)`` after renaming; the inputs are returned
        unchanged when no renaming is necessary.
    """
    renaming = {}
    for position, left_arg in enumerate(left_atom.args):
        if left_arg.startswith("?") and right_atom.args[position].startswith("?"):
            renaming[left_arg] = right_atom.args[position]

    right_names = [variable.name for variable in right.get_variables()]
    left_names = [variable.name for variable in left.get_variables()]

    # Names that must not be chosen as a fresh name; extended as names are assigned.
    taken = set(right_names)
    taken.update(left_names)
    taken.update(renaming.values())

    for variable in left.get_variables():
        if variable.name not in right_names or variable.name in renaming:
            continue
        suffix = 0
        while variable.name + str(suffix) in taken:
            suffix += 1
        fresh_name = variable.name + str(suffix)
        renaming[variable.name] = fresh_name
        taken.add(fresh_name)

    if not renaming:
        return left, left_atom

    # Left and right now only share the variable names of the resolving literals.
    renamed_atom = rename_variables_in_literal(left_atom, renaming)
    renamed_left = rename_variables(left, renaming)
    return renamed_left, renamed_atom


def synchronize_objects(left: Rin17Candidate, right: Rin17Candidate, left_atom: Literal, right_atom: Literal,
                        type_name_to_objects):
    """Decide whether the two literals can be resolved and ground what is needed.

    The literals must share the predicate and have opposite polarity. Wherever
    their arguments differ, the variable on one side is bound to the object on
    the other side (checked against the variable's type). The resulting
    substitution is then applied to both literals and both candidates.

    Position cases:
    a) variable / same variable   -> nothing to do
    b) variable / object          -> bind the variable to the object
    c) object / object, equal     -> nothing to do
    d) object / object, different -> resolution is impossible

    :return: ``(left_atom, right_atom, left, right)`` after substitution, or
        ``(None, None, None, None)`` when resolution is impossible.
    """
    assert (type(left) == Rin17Candidate and type(right) == Rin17Candidate)
    assert (isinstance(left_atom, Literal) and isinstance(right_atom, Literal))  # Atom or NegatedAtom
    assert (right_atom.predicate == left_atom.predicate and right_atom.negated != left_atom.negated)

    failure = (None, None, None, None)
    # The substitution is applied to these untouched literals at the end, see below.
    original_left_atom, original_right_atom = left_atom, right_atom

    variable_to_object = {}
    for i in range(len(left_atom.args)):
        if left_atom.args[i] == right_atom.args[i]:
            continue
        # Variables have been renamed before, so a mismatch means that one side
        # is a variable and the other one an object, or both are different objects.
        if left_atom.args[i].startswith("?"):
            variable = left.get_variable_from_name(left_atom.args[i])
            target_object = right_atom.args[i]
            success, left_atom = partially_ground_atom(left_atom, variable, target_object, type_name_to_objects)
        elif right_atom.args[i].startswith("?"):
            variable = right.get_variable_from_name(right_atom.args[i])
            target_object = left_atom.args[i]
            success, right_atom = partially_ground_atom(right_atom, variable, target_object,
                                                        type_name_to_objects)
        else:
            return failure  # two different objects
        if not success:
            return failure
        variable_to_object[variable] = target_object

    variable_name_to_object = {}
    for var in variable_to_object:
        variable_name_to_object[var.name] = variable_to_object[var]

    # The loop grounds one position at a time, whereas the candidates below are
    # renamed with the complete substitution. Deriving the returned literals
    # from the same substitution keeps them identical to the renamed literals
    # inside the candidates, also when a variable occurs more than once.
    left_atom = rename_variables_in_literal(original_left_atom, variable_name_to_object)
    right_atom = rename_variables_in_literal(original_right_atom, variable_name_to_object)
    if tuple(left_atom.args) != tuple(right_atom.args):
        # A variable would have to take two different values at once.
        return failure

    left = rename_variables(left, variable_name_to_object)
    right = rename_variables(right, variable_name_to_object)

    return left_atom, right_atom, left, right


def build_resolvent(left: Rin17Candidate, left_atom: Literal, right: Rin17Candidate, right_atom: Literal,
                    type_name_to_objects):
    """Combine two candidates into their resolvent on the given literals.

    :return: ``Falsity()`` when no literal is left, ``False`` when
        ``resoltution_possible`` rejects the remaining variables and
        inequality constraints, otherwise the resolvent as ``Rin17Candidate``.
    :raises ValueError: if a resolving literal is not part of its candidate.
    """
    # Disjunction: everything except the two resolving literals.
    remaining_left = list(left.disjunction.parts)
    remaining_left.remove(left_atom)
    remaining_right = list(right.disjunction.parts)
    remaining_right.remove(right_atom)
    resolvent_parts = remaining_left + remaining_right
    if not resolvent_parts:
        return Falsity()

    # Variables: keep only those still mentioned by a remaining literal.
    all_variables = set(left.get_variables())
    all_variables.update(right.get_variables())
    used_variables = set()
    for variable in all_variables:
        if any(variable.name in part.args for part in resolvent_parts):
            used_variables.add(variable)
    unused_variables = all_variables - used_variables

    # Inequality constraints: drop those mentioning a variable that vanished.
    candidate_constraints = set(left.inequality_constraint)
    candidate_constraints.update(right.inequality_constraint)
    kept_constraints = [
        constraint for constraint in candidate_constraints
        if not (constraint[0] in unused_variables or constraint[1] in unused_variables)
    ]

    if resoltution_possible(used_variables, kept_constraints, type_name_to_objects):
        return Rin17Candidate(used_variables, Disjunction(resolvent_parts), kept_constraints)
    return False


def resoltution_possible(combined_variables, combined_ineqc, type_name_to_objects):
    """Tell whether some assignment satisfies all inequality constraints.

    Starts with the assignment created by ``Rin17Assignement`` and keeps calling
    ``increase()`` on it until an assignment passes
    ``test_assignment_against_inequality`` (result ``True``) or ``increase()``
    returns a falsy value (result ``False``).
    NOTE(review): ``increase()`` presumably steps to the next assignment and
    reports exhaustion with a falsy value - confirm in ``rin17_invariants``.
    """
    assignment = Rin17Assignement(combined_variables, type_name_to_objects)
    while not test_assignment_against_inequality(assignment, combined_ineqc):
        if not assignment.increase():
            return False
    return True


def resolve(right: Rin17Candidate, left: Rin17Candidate, type_name_to_objects):
    """Resolve two candidates on their first pair of complementary literals.

    The literals of ``left`` are scanned in the outer loop and those of
    ``right`` in the inner loop; the first pair with the same predicate and
    opposite polarity decides the result.

    :return: whatever ``build_resolvent`` returns for that pair, or ``False``
        when there is no such pair or the object synchronization fails.
    """
    literals_left = left.disjunction.parts
    literals_right = right.disjunction.parts
    for left_atom in literals_left:
        for right_atom in literals_right:
            if not (right_atom.predicate == left_atom.predicate and right_atom.negated != left_atom.negated):
                continue

            # Make the variable names of both sides consistent.
            renamed_left, renamed_left_atom = distinctify_variables(left, left_atom, right, right_atom)
            # Replace variables by objects where one side is already grounded.
            synchronized = synchronize_objects(renamed_left, right, renamed_left_atom, right_atom,
                                               type_name_to_objects)
            ur_left_atom, ur_right_atom, ur_left, ur_right = synchronized
            if not (ur_left and ur_right):
                # NOTE(review): a failed synchronization ends the whole search;
                # later literal pairs are never tried - confirm this is intended.
                return False

            return build_resolvent(ur_left, ur_left_atom, ur_right, ur_right_atom, type_name_to_objects)
    return False


def find_resolvents(previous_kb, resulting_kb, start_index, type_name_to_objects):
    """Yield the truthy results of ``resolve`` for pairs of knowledge-base entries.

    Entries are addressed through ``get_from_joined_list`` by one index running
    over ``previous_kb`` followed by ``resulting_kb``. Every index starting at
    ``start_index`` is paired with all smaller indices.
    """
    previous_size = len(previous_kb)
    outer = start_index
    # len(resulting_kb) is read again on every pass, so entries appended to it
    # while the generator is suspended are visited as well.
    while outer < previous_size + len(resulting_kb):
        for inner in range(outer):
            left = get_from_joined_list(outer, previous_kb, resulting_kb, previous_size)
            right = get_from_joined_list(inner, previous_kb, resulting_kb, previous_size)
            candidate = resolve(right, left, type_name_to_objects)
            if candidate:
                yield candidate
        outer += 1


def rin17_resolution(kb, type_name_to_objects, start_index=0):
    """Collect the resolvents derivable from the knowledge base ``kb``.

    ``kb`` is analogous to a CNF, but over rin17 invariant candidates; it is
    implemented as a list of ``Rin17Candidate``. A candidate consists of
    arguments (typed objects = variables), a disjunction of schematic atoms and
    inequality constraints over typed objects.

    :param start_index: first index handed to ``find_resolvents``.
    :return: the ``Falsity`` resolvent as soon as one is produced, otherwise the
        list of new resolvents that are contained neither in ``kb`` nor in the
        list itself.
    """
    derived = []
    # find_resolvents receives the very list that is filled here.
    resolvents = find_resolvents(kb, derived, start_index, type_name_to_objects)
    for resolvent in resolvents:
        if type(resolvent) is Falsity:
            return resolvent
        already_known = resolvent in derived or resolvent in kb
        if not already_known:
            derived.append(resolvent)
    return derived
