import options
from pddl import Falsity, Conjunction, PropositionalAction, NegatedAtom, Atom, Disjunction
from rin17_invariants import Rin17Candidate
from util import translator_invariant_log


def get_from_joined_list(index, a, b, length_a):
    """
    Read one element from the virtual concatenation of ``a`` and ``b`` without building it.

    Args:
        index: position in the joined sequence
        a: sequence serving all positions below ``length_a``
        b: sequence serving all remaining positions
        length_a: number of leading positions that belong to ``a``

    Returns: ``a[index]`` if ``index`` is smaller than ``length_a``, otherwise ``b[index - length_a]``

    """
    if index >= length_a:
        return b[index - length_a]
    return a[index]


def regress_strips(o: PropositionalAction, c: Conjunction):
    """
    Regress the conjunction ``c`` through the STRIPS operator ``o``.

    The literals of ``c`` may be negated atoms. Therefore, unlike naive STRIPS regression, the delete
    effects have to be considered as well: for the purpose of this function a delete effect behaves
    like a negated add effect.

    Args:
        o: the operator to regress through (exactly a PropositionalAction without conditional effects)
        c: the formula to regress (exactly a Conjunction)

    Returns: ``Falsity()`` if ``o`` deletes a literal of ``c`` or adds the negation of a literal of ``c``,
        otherwise a Conjunction of the sorted literals of ``c`` that ``o`` does not establish,
        followed by the precondition of ``o``

    """
    assert type(c) is Conjunction
    assert type(o) is PropositionalAction
    # start from the literals of c; literals that the operator establishes are dropped below
    remaining = list(c.parts)

    def drop_every_occurrence(literal):
        # A single removal is NOT enough: the naive invariant search produces trivially
        # simplifiable formulas such as (a v a), so the same literal may occur several times.
        while literal in remaining:
            remaining.remove(literal)

    for condition, deletion in o.del_effects:
        assert not condition  # STRIPS operators should never have conditional effects
        if deletion in remaining:
            # c requires an atom that the operator has just deleted
            return Falsity()
        drop_every_occurrence(deletion.negate())

    for condition, addition in o.add_effects:
        assert not condition  # STRIPS operators should never have conditional effects
        drop_every_occurrence(addition)
        if addition.negate() in remaining:
            # c requires the negation of an atom that the operator has just added
            return Falsity()

    remaining.extend(o.precondition)
    return Conjunction(sorted(remaining))


def remove_trivial_operators(operators):
    """
        Removes all operators where PRECONDITIONS model ADD_EFFECTS united with (point wise negation of DEL_EFFECTS)
        Meaning operators that do never change a state where they would be applicable in.

        An operator is only removed if ALL of its add effects are contained in its precondition AND the
        negation of ALL of its delete effects is contained in its precondition. An operator that merely
        re-adds a precondition but also really deletes something is kept.
    Args:
        operators: List of PropositionalAction

    Returns: No return (modifies list inplace)

    """

    def changes_nothing(operator):
        # True if every effect of the operator already holds whenever the operator is applicable.
        precondition = operator.precondition
        return (all(addition in precondition for _, addition in operator.add_effects)
                and all(deletion.negate() in precondition for _, deletion in operator.del_effects))

    for operator in operators:
        assert operator.add_effects or operator.del_effects, "there should never be an operator " \
                                                             "with only preconditions," \
                                                             " but here we go: " + str(operator)

    # Rebuild the list through slice assignment instead of calling remove() inside a loop over the
    # same list: removing while iterating silently skips the element that follows each removal.
    operators[:] = [operator for operator in operators if not changes_nothing(operator)]


def tarjan(vertices, edges):
    """
    Tarjan's algorithm for strongly connected components of a directed graph.

    Source: https://en.wikipedia.org/wiki/Tarjan%27s_strongly_connected_components_algorithm
    Args:
        vertices: iterable of hashable vertices
        edges: iterable of directed edges, each a pair (source, target); it is traversed exactly once

    Returns:
        list of strongly connected components, each component being a list of vertices in the
        order in which they are popped from the algorithm's stack

    """
    # Group the targets by source once, preserving the edge order. Scanning the complete edge list
    # for every visited vertex would cost O(|V| * |E|) and would break for one-shot iterables.
    successors = {}
    for u, w in edges:
        successors.setdefault(u, []).append(w)

    SCCs = []  # strongly connected components
    indices = {}
    lowlink = {}
    on_stack = set()
    stack = []

    def strongconnect(v, index):
        # Visits v with discovery number `index` and returns the next unused discovery number.
        # NOTE: recursion goes one level deeper for every vertex on the current search path.
        indices[v] = index
        lowlink[v] = index
        index += 1
        stack.append(v)
        on_stack.add(v)
        for w in successors.get(v, ()):
            if w not in indices:
                # w has not been visited yet: recurse on it
                index = strongconnect(w, index)
                lowlink[v] = min(lowlink[v], lowlink[w])
            elif w in on_stack:
                # w is on the stack and hence belongs to the component currently being built
                lowlink[v] = min(lowlink[v], indices[w])
        if lowlink[v] == indices[v]:
            # v is the root of a component: everything above it on the stack belongs to it
            scc = []
            while True:
                w = stack.pop()
                on_stack.remove(w)
                scc.append(w)
                if w == v:
                    break
            SCCs.append(scc)
        return index

    index = 0
    for v in vertices:
        if v not in indices:
            index = strongconnect(v, index)
    return SCCs


# Keys of the result dictionary built by invariants_to_types.
# A two-literal invariant is filed according to the polarity of its literals:
#   AVB       - both literals positive
#   AVNB      - mixed literals (exactly one of them negated)
#   NAVNB     - both literals negated
#   REMAINDER - receives invariants whose two literals are equal
AVB = "avb"
AVNB = "avnb"
NAVNB = "navnb"
REMAINDER = "remainder"


def get_mutex_edges(invariants_list):
    """
        Turn every invariant into an edge that contains only positive literals
        (we don't need the information of the negation anymore, they are mutexes)

    Args:
        invariants_list: iterable of invariants, each an iterable of literals

    Returns: list of edges; every edge is a list holding the literals of one invariant in their
        original order, with each negated literal replaced by its negation

    """
    return [[lit.negate() if lit.negated else lit for lit in invariant]
            for invariant in invariants_list]


def _non_trivial_invariants(invariants):
    """
        This is a generator that only yields non-trivial invariants.
        An example for trivial invariants are those that only say
        "this atom cannot be true and false at the same time".

    Args:
        invariants: list of Disjunctions representing invariants

    Returns: next non-trivial invariant

    """
    for invariant in invariants:
        if type(invariant) == Disjunction:
            if len(set(invariant.parts)) == 1:  # the formula is a trivial invariant
                continue
            else:
                yield list(invariant.parts)
        if type(invariant) == Rin17Candidate:
            if len(set(invariant.disjunction.parts)) == 1:  # the formula is a trivial invariant
                continue
            else:
                yield invariant


def _distribute_to_result_buckets(invariant, result):
    """
    File a two-literal invariant into exactly one of the four lists of ``result``.

    Args:
        invariant: either a list with length 2 containing literals, or a Rin17Candidate whose
            disjunction has exactly 2 parts
        result: dictionary with the keys AVB, AVNB, NAVNB and REMAINDER, each mapping to a list

    Returns:
        nothing; ``invariant`` is appended to
            result[REMAINDER] if both literals are equal, otherwise to
            result[AVB]       if no literal is negated,
            result[AVNB]      if exactly one literal is negated,
            result[NAVNB]     if both literals are negated

    """
    # Pre-set to None so that an unsupported argument type reaches the assertion below instead of
    # failing with an UnboundLocalError.
    disjunction_list = None
    if type(invariant) == list:
        disjunction_list = invariant
    elif type(invariant) == Rin17Candidate:
        disjunction_list = list(invariant.disjunction.parts)

    assert disjunction_list, "_distribute_to_result_buckets must be called with either a list or a Rin17_candidate"
    assert len(disjunction_list) == 2, "Invariants at this point should only have lenght 2, len(invariant) = " + str(
        len(disjunction_list))

    if disjunction_list[0] == disjunction_list[1]:
        # Same literal twice: this belongs to the remainder only. Return here, otherwise the
        # invariant would additionally end up in one of the polarity buckets.
        result[REMAINDER].append(invariant)
        return

    amount_negated = sum(1 for literal in disjunction_list if literal.negated)
    # the number of negated literals (0, 1 or 2) selects the bucket
    bucket = (AVB, AVNB, NAVNB)[amount_negated]
    result[bucket].append(invariant)


def invariants_to_types(invariants):
    """
    Sort the non-trivial invariants into buckets according to the polarity of their literals.

    Args:
        invariants: list of invariants (Disjunctions or Rin17Candidates)

    Returns: dictionary with the keys AVB, AVNB, NAVNB and REMAINDER, each mapping to the list of
        two-literal invariants filed under it; all lists are empty if ``invariants`` is empty

    """
    buckets = {AVB: [], AVNB: [], NAVNB: [], REMAINDER: []}
    if not invariants:
        return buckets

    for invariant in _non_trivial_invariants(invariants):
        kind = type(invariant)
        if kind == list:  # Disjunctions arrive here as plain lists of their parts
            length = len(invariant)
            assert length > 1, "at this point, no atomic invariants should be present."
            if length != 2:
                assert False, "we are using 2-SAT, there should never be a formula with size greater 2"
            else:
                _distribute_to_result_buckets(invariant, buckets)
        elif kind == Rin17Candidate:
            parts = invariant.disjunction.parts
            length = len(parts)
            assert length > 1, "at this point, no atomic invariants should be present."
            if length == 2:
                _distribute_to_result_buckets(invariant, buckets)
            else:
                # create corresponding invariants of length 2, keeping parameter and constraint
                # NOTE(review): the second index starts at the first one, so every part is also
                # paired with itself - confirm that these self pairs are intended
                for first in range(length):
                    for second in range(first, length):
                        pair_candidate = Rin17Candidate(invariant.parameter,
                                                        Disjunction([parts[first], parts[second]]),
                                                        invariant.inequality_constraint)
                        _distribute_to_result_buckets(pair_candidate, buckets)
    return buckets


def partition_to_cliques(nodes, edges):
    """
    Rintanens (2006) algorithm for partitioning a graph into cliques.
    Performance claim in Planning Problems:'In many of our example applications ... this algorithm
    identifies all the maximal cliques because no two maximal cliques share a node.'

    The Algorithm is a fix point iteration starting with one set containing all nodes.
    In each iteration, a node that is missing an edge to at least one other node in the
    set is identified and then extracted to a new set. All nodes from the old set that
    are connected to the removed node switch over to the new set as well.
    This is a greedy algorithm without guarantees but reasonable real world performance.

    The set passed as ``nodes`` is left untouched; the partition is built on a copy.
    Args:
        nodes: Atoms to be separated into cliques
        edges: Invariants that constraint mutex groups, each edge a list of two nodes

    Returns: a list of disjoint cliques (no maximality guarantee)

    """

    def connected(node_1, node_2):
        # edges implemented as lists, but intended as undirected edges, therefore this double check
        return [node_1, node_2] in edges or [node_2, node_1] in edges

    def get_node_not_part_of_clique(clique_candidates):
        # Returns (index of the candidate, node) for the first node found that lacks an edge to
        # another node of its candidate set, or (None, None) if every candidate is a clique.
        for clique_index, candidate in enumerate(clique_candidates):
            for node_1 in candidate:
                for node_2 in candidate:
                    if node_2 is not node_1 and not connected(node_1, node_2):
                        return clique_index, node_2  # choosing node_1 works just as well
        return None, None

    all_nodes = set(nodes)
    # Work on a copy: the candidate sets are shrunk in place, which must not alter the caller's set.
    clique_candidates = [set(all_nodes)]
    while True:
        clique_index, node = get_node_not_part_of_clique(clique_candidates)
        # Only clique_index decides about termination. It must be compared with None because 0 is
        # a valid index, and the node itself is not tested because a node object might be falsy.
        if clique_index is None:
            break
        old_candidate = clique_candidates[clique_index]
        new_candidate = {node} | {further_node for further_node in old_candidate
                                  if connected(further_node, node)}
        clique_candidates.append(new_candidate)
        old_candidate.difference_update(new_candidate)

    # Safeguard: the steps above only move nodes between candidates, so normally nothing is left
    # over; a node that is nevertheless missing becomes a group of its own.
    found = set()
    for group in clique_candidates:
        found.update(group)
    dif = all_nodes - found
    for atom in dif:
        clique_candidates.append([atom])
    if options.invariant_debug:
        translator_invariant_log("atoms not in any mutex group", len(dif))
    return clique_candidates


def build_literal(predicate, args, negated):
    """
    Create a literal for ``predicate`` applied to ``args``.

    Args:
        predicate: passed on unchanged as first constructor argument
        args: passed on unchanged as second constructor argument
        negated: selects the polarity of the literal

    Returns: a NegatedAtom if ``negated`` is truthy, otherwise an Atom

    """
    literal_class = NegatedAtom if negated else Atom
    return literal_class(predicate, args)
