#! /usr/bin/env python3
import copy
import datetime
import os
from collections import defaultdict
from os import mkdir

import options
import timers
from pddl import Disjunction, Conjunction, Falsity, Atom, NegatedAtom, Action, PropositionalAction, Predicate, \
    TypedObject, Type, Task
from rin08_invariant_finder import build_pointerlist, satisfies
from rin17_invariants import Rin17Candidate, Rin17PartialOperator
from rin17_resolution import rin17_resolution, ResolutionMonitor
from rin_invariant_util import regress_strips, remove_trivial_operators, invariants_to_types, \
    partition_to_cliques, get_mutex_edges, NAVNB, build_literal


def prmts(pddl_object, target_type: str):
    """Count the parameters of an action, or the arguments of a predicate, typed ``target_type``.

    :param pddl_object: an ``Action`` (the values of its ``type_map`` are inspected) or a
        ``Predicate`` (the ``type_name`` of each argument is inspected); the exact class is
        checked, so any other object yields 0
    :param target_type: name of the type to count
    :return: number of occurrences of ``target_type``
    """
    kind = type(pddl_object)
    if kind is Action:
        typed_names = pddl_object.type_map.values()
    elif kind is Predicate:
        typed_names = [argument.type_name for argument in pddl_object.arguments]
    else:
        return 0
    return sum(1 for type_name in typed_names if type_name == target_type)


def recursively_build_partial_ground_formula(c, type_name_to_objects, indexes, results, affected_atoms, position):
    """Enumerate object assignments for the variables of ``c`` and collect the affected groundings.

    ``indexes[k]`` selects, for the k-th variable of ``c``, an object from
    ``type_name_to_objects[<type of that variable>]``. Positions before ``position`` are fixed
    by the callers. This call tries every object for ``position``, starting at the current
    ``indexes[position]``, and recurses. Once all positions are fixed, the literals of
    ``c.disjunction`` are instantiated. ``Disjunction(<instantiated literals>)`` is appended to
    ``results`` if every inequality constraint of ``c`` holds and at least one instantiated
    literal, taken as a positive atom, is contained in ``affected_atoms``.

    :param c: the candidate invariant (a ``Rin17Candidate``)
    :param type_name_to_objects: type name -> list of object names used for grounding
    :param indexes: current object index per variable; the caller's list is not mutated
    :param results: output list, extended in place
    :param affected_atoms: atoms added or deleted by the operator (see
        ``partially_ground_candidate``)
    :param position: index of the variable to assign next
    """
    if position == len(indexes):
        # Deepest point of the recursion: every variable has an object, so build the formula.
        result_atoms = []
        # Positive version of every literal, used for the overlap test with affected_atoms.
        affection_test_result_atoms = []
        for atom in c.disjunction.parts:
            if set(c.get_type_names(atom.args)).intersection(type_name_to_objects.keys()):
                args = []
                for arg in atom.args:
                    type_name = c.get_type_name(arg)
                    if type_name in type_name_to_objects:
                        args.append(type_name_to_objects[type_name][indexes[c.get_index(arg)]])
                    else:
                        args.append(arg)
                if atom.negated:
                    result_atoms.append(NegatedAtom(atom.predicate, args))
                else:
                    result_atoms.append(Atom(atom.predicate, args))
                affection_test_result_atoms.append(Atom(atom.predicate, args))
            else:
                result_atoms.append(atom)
                # Use the positive atom here as well. A negated literal would otherwise never
                # match an affected atom, and the formula would never reach the regression test.
                affection_test_result_atoms.append(Atom(atom.predicate, atom.args))

        # Check the inequality constraints; members are object names (str) or variables.
        inequality_constraints_hold = True
        for iec in c.inequality_constraint:
            to_test = []
            for ie in iec:
                if type(ie) != str and c.get_type_name(ie.name) in type_name_to_objects:
                    to_test.append(type_name_to_objects[c.get_type_name(ie.name)][indexes[c.get_index(ie.name)]])
                else:
                    to_test.append(ie)
            if to_test[0] == to_test[1]:
                inequality_constraints_hold = False
                break

        # affected_atoms and affection_test_result_atoms both hold positive ground atoms.
        if inequality_constraints_hold and not set(affection_test_result_atoms).isdisjoint(affected_atoms):
            results.append(Disjunction(result_atoms))
        return

    # Copy so that index changes made at this level are not seen by the caller.
    indexes = indexes.copy()
    domain = type_name_to_objects[c.get_variables()[position].type_name]
    while indexes[position] < len(domain):
        recursively_build_partial_ground_formula(c, type_name_to_objects, indexes, results, affected_atoms,
                                                 position + 1)
        indexes[position] += 1


def partially_ground_candidate(c: Rin17Candidate, o: PropositionalAction, type_name_to_objects):
    """Return the groundings of candidate ``c`` that are relevant for operator ``o``.

    Every variable of ``c`` is instantiated with every object of its type from
    ``type_name_to_objects``. The filtering is done by
    ``recursively_build_partial_ground_formula``: a grounding is kept only if it satisfies the
    inequality constraints of ``c`` and overlaps the atoms added or deleted by ``o``. The result
    is a list of ground ``Disjunction`` objects, which the caller handles like STRIPS formulas.
    """
    # One object index per variable, all starting at the first object of the variable's type.
    indexes = [0] * len(c.get_variables())
    # Never build formulas that are not affected by the operator.
    affected_atoms = {atom for _, atom in o.add_effects}
    affected_atoms.update(atom for _, atom in o.del_effects)
    possible_formulas = []
    recursively_build_partial_ground_formula(c, type_name_to_objects, indexes, possible_formulas, affected_atoms,
                                             position=0)
    return possible_formulas


def satisfies_rin17(c, kb, type_name_to_objects):
    """Run an incomplete satisfiability test of ``c`` together with the knowledge base ``kb``.

    By design the test errs towards SAT. It may answer True for an unsatisfiable input, but must
    never answer False for a satisfiable one, because that would produce incorrect invariants.
    The incompleteness therefore only means fewer invariants are found.

    :param c: a ``Falsity``, ``Conjunction``, ``Disjunction``, ``Rin17Candidate`` or a list of
        ``Rin17Candidate``
    :param kb: sequence of ``Rin17Candidate`` (copied, not modified)
    :param type_name_to_objects: type name -> object names, forwarded to ``rin17_resolution``
    :return: False if ``c`` is a ``Falsity`` or the resolution returns a ``Falsity``,
        otherwise True
    """
    if isinstance(c, Falsity):
        return False

    if type(c) == Conjunction:
        # Every conjunct becomes a variable-free unit clause.
        clauses = [Rin17Candidate(variables=[], disjunction=Disjunction([part])) for part in c.parts]
    elif type(c) == Disjunction:
        clauses = [Rin17Candidate([], disjunction=c)]
    elif type(c) == Rin17Candidate:
        clauses = [c]
    else:
        clauses = c

    assert type(clauses) == list, "The satisfies method from our rintan 17 implementation was called with type: " + str(
        type(clauses))

    first_new_index = len(kb)
    resolution_input = list(kb) + clauses
    outcome = rin17_resolution(resolution_input, type_name_to_objects, start_index=first_new_index)
    return type(outcome) != Falsity


def get_names(candidate_para):
    """Return the ``name`` attribute of every element of ``candidate_para``, in order."""
    return [parameter.name for parameter in candidate_para]


def weaken(candidate: Rin17Candidate, predicates, type_name_to_objects):
    """Return candidates that are logically weaker than ``candidate``.

    Three kinds of weakening are generated, in this order:

    1. Object substitution, only if the candidate has more than one variable. Each variable is
       replaced by each object of its type and the inequality constraints are updated
       accordingly. An object that a constraint forbids for that variable is skipped, and the
       next object is tried.
    2. Inequality constraints. For every unordered pair of distinct same-typed variables that
       is not yet constrained, a copy with the extra constraint is added.
    3. Extra literals, only while the disjunction has fewer than
       ``options.rintanen_max_invarsize`` literals. For each predicate in ``predicates``, a
       positive and a negative literal, with variables renamed apart from the candidate's, are
       appended. The candidate's inequality constraints are kept.

    :param candidate: the ``Rin17Candidate`` to weaken
    :param predicates: predicates that may be used for additional literals
    :param type_name_to_objects: type name -> object names available for substitution
    :return: list of new ``Rin17Candidate`` objects
    """
    weakened_candidates = []
    if len(candidate.get_variables()) > 1:
        weakened_candidates.extend(_weaken_by_objects(candidate, type_name_to_objects))
    weakened_candidates.extend(_weaken_by_inequalities(candidate))
    if len(candidate.disjunction.parts) < options.rintanen_max_invarsize:
        weakened_candidates.extend(_weaken_by_literals(candidate, predicates))
    return weakened_candidates


def _weaken_by_objects(candidate, type_name_to_objects):
    """Return one candidate per admissible (variable, object of the variable's type) substitution."""
    results = []
    for variable in candidate.get_variables():
        for obj in type_name_to_objects[variable.type_name]:
            inequality_constraint = _substitute_object_in_inequalities(candidate.inequality_constraint,
                                                                       variable, obj)
            if inequality_constraint is None:
                # A constraint forbids obj for this variable; the other objects may still be used.
                continue
            variables = list(candidate.get_variables())
            variables.remove(variable)
            new_parts = []
            for literal in candidate.disjunction.parts:
                if variable.name in literal.args:
                    args = [obj if arg == variable.name else arg for arg in literal.args]
                    new_parts.append(build_literal(literal.predicate, args, literal.negated))
                else:
                    new_parts.append(literal)
            results.append(Rin17Candidate(variables, Disjunction(new_parts), inequality_constraint))
    return results


def _substitute_object_in_inequalities(inequality_constraint, variable, obj):
    """Return the inequality constraints with ``variable`` replaced by object ``obj``.

    Each constraint is a pair whose members are variables (``TypedObject``) or object names
    (``str``). If a constraint reads ``variable != obj``, the substitution is impossible and
    None is returned. ``variable != other_object`` becomes trivially true and is dropped.
    ``variable != other_variable`` is rewritten to ``(other_variable, obj)`` and placed after the
    unchanged constraints, which keep their order.
    """
    kept = []
    rewritten = []
    for iec in inequality_constraint:
        if variable not in [member for member in iec if type(member) == TypedObject]:
            kept.append(iec)
            continue
        if obj in [member for member in iec if type(member) == str]:
            return None
        # The member that is not `variable` (assumed to be exactly one).
        other = [member for member in iec if type(member) != TypedObject or member != variable][0]
        if type(other) != str:
            rewritten.append((other, obj))
    return kept + rewritten


def _weaken_by_inequalities(candidate):
    """Return one candidate per unconstrained pair of distinct variables of the same type."""
    results = []
    variables = set(candidate.get_variables())
    seen = set()
    for var_a in variables:
        seen.add(var_a)
        for var_b in variables.difference(seen):
            if var_a.type_name == var_b.type_name and not candidate.is_in_inequaltiy_constraint((var_a, var_b)):
                results.append(Rin17Candidate(candidate.get_variables(), candidate.disjunction,
                                              list(candidate.inequality_constraint) + [(var_a, var_b)]))
    return results


def _weaken_by_literals(candidate, predicates):
    """Return, per predicate, the candidate extended by a positive and by a negative literal."""
    results = []
    for predicate in predicates:
        literal_predicate = _rename_apart(predicate, candidate.get_variables())
        for negated in (False, True):
            parts = list(candidate.disjunction.parts)
            parts.append(predicate_to_Atom(literal_predicate, negated=negated))
            new_variables = list(candidate.get_variables()) + list(literal_predicate.arguments)
            # Keep the inequality constraints: dropping them would not yield a weaker formula.
            results.append(Rin17Candidate(new_variables, Disjunction(parts), list(candidate.inequality_constraint)))
    return results


def _rename_apart(predicate, candidate_variables):
    """Return ``predicate`` with every argument whose name clashes with a candidate variable renamed.

    A clashing argument ``?x`` becomes ``?x<i>`` for the smallest ``i`` such that the new name is
    used neither by a candidate variable, by another argument of ``predicate``, nor by an
    already renamed argument. Without any clash, ``predicate`` itself is returned.
    """
    candidate_names = {variable.name for variable in candidate_variables}
    if all(argument.name not in candidate_names for argument in predicate.arguments):
        return predicate
    taken = candidate_names | {argument.name for argument in predicate.arguments}
    arguments = []
    for argument in predicate.arguments:
        if argument.name not in candidate_names:
            arguments.append(argument)
            continue
        i = 0
        while argument.name + str(i) in taken:
            i += 1
        fresh_name = argument.name + str(i)
        taken.add(fresh_name)
        arguments.append(TypedObject(fresh_name, argument.type_name))
    return Predicate(predicate.name, arguments)


def predicate_to_Atom(predicate, negated=False):
    """Build a literal over ``predicate`` whose arguments are the names of its parameters.

    :param predicate: object with ``name`` and ``arguments`` (each with a ``name``)
    :param negated: build the negated literal instead of the positive one
    """
    argument_names = [argument.name for argument in predicate.arguments]
    return build_literal(predicate.name, argument_names, negated)


def get_initial_candidates(relevant_predicates, type_name_to_objects, task, atoms):
    """Compute the initial candidate invariants, i.e. candidates consistent with the initial state.

    The initial state restricted to ``atoms`` is encoded as variable-free unit clauses. Each
    atom is used as is if it is in ``task.init``, and negated otherwise.

    For every relevant predicate, the two single-literal candidates ``P(args)`` and
    ``not P(args)`` over the predicate's declared arguments are put on an open list. Open
    candidates that are consistent with the initial state (``satisfies_rin17``) are kept. The
    others are replaced by their weakenings (``weaken``), which may only add literals over
    predicates that have been "opened".

    When one of a predicate's two initial candidates fails, the predicate is opened only if
    neither of those two candidates is still open or kept. Otherwise the failing candidate is
    dropped without being weakened.

    :param relevant_predicates: predicates whose names occur in ``atoms``
        (as computed by ``get_relevant_predicates``)
    :param type_name_to_objects: type name -> object names, forwarded to the SAT test and weaken
    :param task: the task (``task.init`` and ``task.predicates`` are read)
    :param atoms: the atoms to consider (``pddl.Atom``)
    :return: list of ``Rin17Candidate`` objects, in no particular order
    """
    # A set gives O(1) membership tests instead of scanning the init list once per atom.
    initial_state = set(task.init)
    predicates_to_initial_literals = {predicate.name: [] for predicate in relevant_predicates}
    for atom in atoms:
        literal = atom if atom in initial_state else atom.negate()
        predicates_to_initial_literals[atom.predicate].append(literal)

    open_list = []
    predicate_to_initial_candidates = {}
    for predicate in relevant_predicates:
        for p in task.predicates:
            if p.name == predicate.name:
                pos = Rin17Candidate(p.arguments, Disjunction([predicate_to_Atom(p, negated=False)]), [])
                neg = Rin17Candidate(p.arguments, Disjunction([predicate_to_Atom(p, negated=True)]), [])
                open_list.append(pos)
                open_list.append(neg)
                predicate_to_initial_candidates[p.name] = [pos, neg]
                break

    # Knowledge base describing the initial state.
    init = [Rin17Candidate([], Disjunction([literal]), [])
            for literals in predicates_to_initial_literals.values()
            for literal in literals]
    initial_candidates_from_predicates = [candidate
                                          for pair in predicate_to_initial_candidates.values()
                                          for candidate in pair]

    result = set()
    expanded = set()
    opened_predicates = []
    while open_list:
        candidate = open_list.pop()
        if candidate in result or candidate in expanded:
            continue
        expanded.add(candidate)
        if satisfies_rin17(candidate, init, type_name_to_objects):
            result.add(candidate)
            continue
        if candidate in initial_candidates_from_predicates:
            predicate_name = candidate.disjunction.parts[0].predicate
            pos, neg = predicate_to_initial_candidates[predicate_name]
            # Only open a predicate once neither of its initial candidates is open or kept.
            if pos in open_list or neg in open_list or pos in result or neg in result:
                continue
            for rel_pred in relevant_predicates:
                if rel_pred.name == predicate_name:
                    opened_predicates.append(rel_pred)
                    break
        open_list.extend(weaken(candidate, opened_predicates, type_name_to_objects))

    return list(result)


def affects(operator: PropositionalAction, formula: Disjunction):
    """Tell whether ``operator`` can change ``formula``.

    :return: True iff some add or delete effect of ``operator`` has a predicate that occurs in
        one of the literals of ``formula``
    """
    effect_predicates = {atom.predicate for _, atom in operator.add_effects}
    effect_predicates.update(atom.predicate for _, atom in operator.del_effects)
    return any(literal.predicate in effect_predicates for literal in formula.parts)


def get_relevant_predicates(task, atoms):
    """Return the predicates of ``task``, in task order, that occur in at least one of ``atoms``.

    :param task: object with ``predicates`` (each with a ``name``)
    :param atoms: iterable of atoms (each with a ``predicate`` name)
    :return: list of the matching predicate objects
    """
    # Build the name set once instead of rescanning all atoms for every predicate.
    atom_predicate_names = {atom.predicate for atom in atoms}
    return [predicate for predicate in task.predicates if predicate.name in atom_predicate_names]


def get_unary_predicates_coding_types(task, atoms):
    """Return the unary predicates of ``task`` that occur in none of ``atoms``, in task order.

    ``rin17_get_groups`` hands these predicates to ``typefy_task``, which turns them into types.
    """
    # A set gives O(1) membership tests; the original list was rescanned for every predicate.
    atom_predicates = {atom.predicate for atom in atoms}
    return [predicate for predicate in task.predicates
            if predicate.name not in atom_predicates and len(predicate.arguments) == 1]


def typefy_task(task, unary_predicates_indicating_types):
    """Return a new ``Task`` in which unary "type" predicates are turned into PDDL types.

    * Objects: an object ``o`` with an initial atom ``T(o)`` for a type predicate ``T`` gets
      type ``T``. A ``Type(T)`` is created if no type of that name is present yet.
    * Actions: a parameter with a precondition conjunct ``T(?p)`` gets type ``T``, and all
      type-predicate conjuncts are removed from the precondition.
    * Predicates: type predicates, and predicates that occur in no precondition or effect, are
      dropped. A remaining predicate gets typed arguments when every argument occurrence in the
      actions is a parameter typed this way and the number of distinct argument types seen
      equals its arity. Types are assigned in order of first appearance. Otherwise the
      predicate is kept as is.

    :param task: the original task
    :param unary_predicates_indicating_types: predicates to interpret as types (see
        ``get_unary_predicates_coding_types``)
    :return: the new ``Task``
    """
    type_predicate_names = {p.name for p in unary_predicates_indicating_types}

    # object name -> type predicate that holds for it in the initial state
    objects_in_initial_unarily = {}
    for atom in task.init:
        # Skip init entries that are not positive atoms (e.g. numeric assignments, if present),
        # since they may lack `args`.
        if isinstance(atom, Atom) and len(atom.args) == 1 and atom.predicate in type_predicate_names:
            objects_in_initial_unarily[atom.args[0]] = atom.predicate

    actions = []
    predicate_to_occurring_types = defaultdict(list)
    predicates_in_effects_or_preconditions = set()
    # Predicates with an argument occurrence whose type is not given by a type predicate.
    untypeable_predicates = set()
    for action in task.actions:
        parameter_to_type = {}
        for part in action.precondition.parts:
            if part.predicate in type_predicate_names:
                assert len(part.args) == 1
                parameter_to_type[part.args[0]] = part.predicate
        parameters = [TypedObject(parameter.name, parameter_to_type[parameter.name])
                      if parameter.name in parameter_to_type else parameter
                      for parameter in action.parameters]
        precondition = Conjunction(
            [p for p in action.precondition.parts if p.predicate not in type_predicate_names])
        new_action = Action(action.name, parameters, action.num_external_parameters, precondition, action.effects,
                            action.cost)
        actions.append(new_action)

        # Record, per predicate, the distinct argument types in order of first appearance.
        literals = [effect.literal for effect in new_action.effects] + list(new_action.precondition.parts)
        for literal in literals:
            predicates_in_effects_or_preconditions.add(literal.predicate)
            for arg in literal.args:
                if arg not in parameter_to_type:
                    # Untyped parameter, effect variable or constant: the argument types of
                    # this predicate cannot be inferred.
                    untypeable_predicates.add(literal.predicate)
                    continue
                arg_type = parameter_to_type[arg]
                occurring = predicate_to_occurring_types[literal.predicate]
                if arg_type not in occurring:
                    occurring.append(arg_type)

    # NOTE(review): this keeps only original types whose name is also the name of a predicate
    # used in a precondition or effect; confirm that this filter is intended.
    type_name_to_type = {t.name: t for t in task.types if t.name in predicates_in_effects_or_preconditions}
    objects = []
    for obj in task.objects:
        if obj.name not in objects_in_initial_unarily:
            objects.append(obj)
            continue
        type_name = objects_in_initial_unarily[obj.name]
        objects.append(TypedObject(obj.name, type_name))
        # The keys are type names, so compare the name (not a Type instance) against them.
        if type_name not in type_name_to_type:
            type_name_to_type[type_name] = Type(type_name)
    types = list(type_name_to_type.values())

    predicate_to_arguments_length = {p.name: len(p.arguments) for p in task.predicates}
    predicate_to_occurring_types = {
        name: arg_types for name, arg_types in predicate_to_occurring_types.items()
        if name not in untypeable_predicates and len(arg_types) == predicate_to_arguments_length[name]}
    predicates = []
    for p in task.predicates:
        if p.name in type_predicate_names or p.name not in predicates_in_effects_or_preconditions:
            continue
        if p.name in predicate_to_occurring_types:
            arg_types = predicate_to_occurring_types[p.name]
            assert len(p.arguments) == len(arg_types)
            # Most predicates only ever take one specific type per argument.
            predicates.append(Predicate(p.name, [TypedObject(arg.name, arg_type)
                                                 for arg, arg_type in zip(p.arguments, arg_types)]))
        else:
            predicates.append(p)

    return Task(task.domain_name, task.task_name, task.requirements,
                types, objects, predicates, task.functions, task.init, task.goal,
                actions, task.axioms, task.use_min_cost_metric)


def _list_to_string(list_of_rin17_candidates):
    """Return the ``name`` of every candidate, each followed by a newline ("" for no candidates)."""
    # str.join avoids the quadratic cost of repeated string concatenation.
    return "".join(candidate.name + "\n" for candidate in list_of_rin17_candidates)


# Selects the satisfiability test in rin17_invariants. When True, the candidates are grounded
# (ground_invariants) and tested with the rin08 `build_pointerlist`/`satisfies` machinery.
# When False, the rin17 resolution-based `satisfies_rin17` is used instead.
SAT2 = True


def rin17_invariants(task, pgo, type_name_to_objects, atoms):  # S-IRIS
    """Compute schematic invariants following Rintanen (2017).

    Starting from ``get_initial_candidates``, every candidate ``c`` is tested against every
    operator in ``pgo`` that ``affects`` it. For each grounding of ``c`` returned by
    ``partially_ground_candidate``, the negation of the grounding is regressed through the
    operator. If the regression is satisfiable together with a snapshot of the candidates taken
    at the start of the iteration, ``c`` is removed and its weakenings (``weaken``) are added.
    A candidate that no operator affects is removed as well. This repeats until an iteration
    changes nothing.

    Notation of the paper:
    A = ``atoms`` (set of pddl.Atom);
    I = ``task.init`` (only the true atoms, many of them irrelevant, e.g. ball(ball2));
    O = ``pgo`` (list of pddl.PropositionalAction);
    n = ``options.rintanen_max_invarsize``.

    Side effects: ``remove_trivial_operators(pgo)`` is called and its result is not used, so it
    presumably modifies ``pgo`` in place. With ``options.invariant_debug`` set, the candidates of
    every iteration are written to ``output/<timestamp>/``.

    :return: the list of surviving ``Rin17Candidate`` objects
    """
    n = options.rintanen_max_invarsize
    assert n >= 1, "Rintanens Invariant Algorithm cannot produce invariants of size less than one"

    remove_trivial_operators(pgo)
    if options.invariant_debug:
        start_ts = str(datetime.datetime.now()).replace(":", "").replace(".", "").replace(" ", "_")
        directory_path = "output/" + start_ts
        os.makedirs(directory_path, exist_ok=True)
        counter = 0
    relevant_predicates = get_relevant_predicates(task, atoms)
    # C: list of the current candidates (Rin17Candidate)
    C = get_initial_candidates(relevant_predicates, type_name_to_objects, task, atoms)

    if SAT2:
        # All atoms and their negations, passed to the rin08 `satisfies` test.
        literal_list = list(atoms) + [a.negate() for a in atoms]
    list_changed = True
    while list_changed:
        with timers.timing("one iteration:"):
            list_changed = False
            # Snapshot of C for this iteration. Only the rin17-resolution test reads it, so the
            # deep copy is skipped when SAT2 is enabled.
            Ctemp = None if SAT2 else copy.deepcopy(C)
            if options.invariant_debug:
                with open(directory_path + "/iteration" + str(counter) + ".txt", "w") as debug_file:
                    debug_file.write("Ctemp in run " + str(counter) + "::\n" + _list_to_string(C))
            if SAT2:
                # Grounded snapshot of C for this iteration.
                gnd_C = ground_invariants(C, type_name_to_objects)
                pointer_list = build_pointerlist(gnd_C)
            i = 0
            while i < len(C):
                c = C[i]
                cnt_unaffecting = 0
                for o in pgo:
                    if affects(o, c.disjunction):
                        for c_grnd_formula in partially_ground_candidate(c, o, type_name_to_objects):
                            rnn = regress_strips(o, c_grnd_formula.negate())
                            if SAT2:
                                sat = satisfies(rnn, pointer_list, gnd_C, literal_list)
                            else:
                                sat = satisfies_rin17(rnn, Ctemp, type_name_to_objects)
                            if sat:
                                # o may falsify c: replace c by its weakenings.
                                C.pop(i)
                                list_changed = True
                                # Weakening knows n and only produces short enough formulas.
                                if len(c.disjunction.parts) <= n:
                                    for candidate in weaken(c, relevant_predicates, type_name_to_objects):
                                        if candidate not in C:
                                            C.append(candidate)
                                # i is not increased: the pop shifted the next candidate to index i.
                                break
                        else:
                            # No grounding of c is falsified by o: try the next operator.
                            continue
                        # c has been removed: stop testing operators.
                        break
                    else:
                        cnt_unaffecting += 1
                else:
                    # c survived every operator.
                    i += 1
                if cnt_unaffecting == len(pgo):
                    # No operator affects c: drop it (it survived, so i was just increased).
                    C.pop(i - 1)
                    list_changed = True
                    i -= 1

            if options.invariant_debug:
                counter += 1

    if options.invariant_debug:
        with open(directory_path + "/final.txt", "w") as debug_file:
            debug_file.write("Ctemp in last run:\n" + _list_to_string(C))
    return C


def recursively_build_partial_ground_operator(o, type_to_objects, position, result):
    """Append every instantiation of partial operator ``o`` from parameter ``position`` on to ``result``.

    On a copy of ``o``, parameter ``position`` of ``o.operator`` is bound to each object of its
    type in ``type_to_objects``, and the function recurses on the next position. Once all
    parameters are bound, ``o.instantate()`` is appended to ``result``.
    """
    parameters = o.operator.parameters
    if position < len(parameters):
        for candidate_object in type_to_objects[parameters[position].type_name]:
            extended = o.copy()
            extended.partially_instantiate(position, candidate_object)
            recursively_build_partial_ground_operator(extended, type_to_objects, position + 1, result)
    else:
        # NOTE(review): `instantate` must match the method name on the partial-operator class.
        result.append(o.instantate())


def partially_ground_operators(task):
    """Partially ground all actions of ``task`` over a bounded number of objects per type.

    For each type, the bound is
    ``L = max(a, p) + (options.rintanen_max_invarsize - 1) * p``.
    Here ``a`` is the largest number of parameters of that type in any action and ``p`` the
    largest number of arguments of that type in any predicate, both counted with ``prmts``.
    The first ``L`` objects of each type, in ``task.objects`` order, are used.

    :return: ``(operators, type_name_to_objects)``, i.e. the list of partially ground operators
        and, per type name, the object names that were used
    :raises AssertionError: if a type has fewer than ``L`` objects
    """
    type_name_to_L = {}
    for pddl_type in task.types:
        type_name = pddl_type.name
        in_actions = max((prmts(action, type_name) for action in task.actions), default=0)
        in_predicates = max((prmts(predicate, type_name) for predicate in task.predicates), default=0)
        type_name_to_L[type_name] = max(in_actions, in_predicates) + (
                options.rintanen_max_invarsize - 1) * in_predicates

    type_name_to_objects = {}
    for pddl_type in task.types:
        bound = type_name_to_L[pddl_type.name]
        names = [obj.name for obj in task.objects if obj.type_name == pddl_type.name]
        assert len(names) >= bound
        type_name_to_objects[pddl_type.name] = names[:bound]

    operators = []
    for action in task.actions:
        recursively_build_partial_ground_operator(Rin17PartialOperator(action, {}), type_name_to_objects, 0,
                                                  operators)
    return operators, type_name_to_objects


def get_grounding(task, type_name_to_objects):
    """
    Map every type name in ``type_name_to_objects`` to the names of all objects of that type.

    ``type_name_to_objects`` may hold only a bounded subset of the objects (see
    ``partially_ground_operators``). The result instead lists every object of ``task`` whose
    ``type_name`` equals the type name, in task order.

    :param task: the current task
    :param type_name_to_objects: mapping whose keys select the types; its values are ignored
    :return: dict type name -> list of object names
    :raises AssertionError: if some selected type has no object
    """
    grounding = {}
    for wanted_type in type_name_to_objects:
        names = [candidate.name for candidate in task.objects if candidate.type_name == wanted_type]
        assert len(names) > 0, "There must always be at least on object to each type in a task"
        grounding[wanted_type] = names
    return grounding


def ground_invariants(invariants, type_name_to_ground_objects):
    """Return the set of all instantiations of ``invariants``.

    The results of ``invariant.instantiate(type_name_to_ground_objects)`` for all invariants
    are merged into one set, so duplicates across invariants collapse.
    """
    return {instance
            for invariant in invariants
            for instance in invariant.instantiate(type_name_to_ground_objects)}


def rin17_get_groups(task, atoms):
    """Partition ``atoms`` into mutex groups using Rintanen-2017 schematic invariants.

    Steps:
    1. Unary predicates that occur in none of ``atoms`` are turned into types (``typefy_task``).
    2. The actions are partially ground over a bounded number of objects per type.
    3. The invariants found for that bounded problem are ground with all objects of the typed
       task.
    4. The mutex edges derived from the NAVNB-kind invariants are partitioned into cliques.

    :return: the result of ``partition_to_cliques`` over a copy of ``atoms``
    """
    type_predicates = get_unary_predicates_coding_types(task, atoms)
    typed_task = typefy_task(task, type_predicates)
    operators, bounded_objects = partially_ground_operators(typed_task)
    # Constructed only for its (static) side effect; the instance itself is not kept.
    ResolutionMonitor([t.name for t in typed_task.types])
    schematic_invariants = rin17_invariants(typed_task, operators, bounded_objects, atoms)

    # Naive approach: instantiate every schematic invariant with all objects of its types.
    # A smarter approach could use the schematic invariants to infer mutex groups directly,
    # or at least to seed the clique algorithm.
    all_objects = get_grounding(typed_task, bounded_objects)
    grounded_invariants = ground_invariants(schematic_invariants, all_objects)
    invariants_by_type = invariants_to_types(grounded_invariants)
    mutex_edges = get_mutex_edges(invariants_by_type[NAVNB])
    return partition_to_cliques(nodes=atoms.copy(), edges=mutex_edges)
