#! /usr/bin/env python3
import datetime
import time
from collections import defaultdict
from os import mkdir

import options
import timers
import copy
from pddl import Disjunction, Conjunction, Falsity
from rin_invariant_util import get_from_joined_list, regress_strips, remove_trivial_operators, tarjan, \
    invariants_to_types, get_mutex_edges, partition_to_cliques, NAVNB
from util import invar_search_output, log_number_mutexes, translator_invariant_log

# Selects the satisfiability test used by satisfies(): True uses unit propagation
# followed by a 2-SAT implication-graph check (sat2); False saturates the clause
# set by resolution (rin08_resolution). rin08_invariants also builds a different
# kind of knowledge base depending on this flag.
SAT2 = True

# Reference point for the time limit in rin08_invariants, which compares
# time.process_time() - start_time against options.invariant_generation_max_time.
# It is never reassigned in this part of the file - presumably set by the caller; verify.
start_time = 0


# kb: list<Disjunction> on the resolution path; on the SAT2 path the literal -> clauses map passed on to sat2
def satisfies(c: Conjunction, kb, C=None, literal_list=None, change=False):
    """Return False only if ``c`` is proven unsatisfiable together with the knowledge base.

    The test is incomplete on purpose and defaults to "satisfiable": answering
    UNSAT for a satisfiable instance would lead to incorrect invariants, while
    a spurious SAT answer only means fewer invariants are found.

    :param c: conjunction whose parts are literals, or ``Falsity``.
    :param kb: on the SAT2 path (``SAT2`` set and ``change`` false) the mapping
        handed to ``sat2`` as its pointer list; otherwise a list of Disjunctions
        forming a CNF.
    :param C: the candidate clauses; required on the SAT2 path.
    :param literal_list: all literals (both polarities); required on the SAT2 path.
    :param change: use the resolution path even when ``SAT2`` is set.
    :return: True unless unsatisfiability was derived.
    """
    if isinstance(c, Falsity):
        return False

    if not SAT2 and not change:
        # If every part of c already is a clause of kb, adding c leaves kb unchanged,
        # so the query is as satisfiable as kb itself (assumed satisfiable here).
        # The parts of c are literals (they are wrapped into unit Disjunctions on the
        # resolution path below), so they are wrapped the same way for the lookup;
        # a bare literal presumably never compares equal to one of kb's Disjunctions.
        as_clauses = (part if isinstance(part, Disjunction) else Disjunction([part])
                      for part in c.parts)
        if all(clause in kb for clause in as_clauses):
            return True

    if SAT2 and not change:
        assert C, "for SAT 2, We must give C into satisfies"
        assert literal_list, "for SAT 2, we want to have a list of all literals"
        return sat2(C, pointer_list=kb, c=c, literal_list=literal_list)

    # Resolution path: append c's literals as unit clauses and saturate. Only clauses
    # from index kb_original_length on are used as the first resolution partner.
    kb_original_length = len(kb)
    extended_kb = list(kb)
    extended_kb.extend(Disjunction([literal]) for literal in c.parts)
    result = rin08_resolution(extended_kb, start_index=kb_original_length)
    return not isinstance(result, Falsity)


def rin08_resolution(kb, start_index=0):
    """Close ``kb`` under resolution, or report a contradiction.

    :param kb: CNF given as a list of Disjunctions; it is read, never modified.
    :param start_index: clauses before this index (in the joined sequence of
        ``kb`` followed by the derived resolvents) are never used as the first
        resolution partner; the second partner ranges over all clauses.
        satisfies() passes the length of its original kb so that resolution
        starts from the unit clauses it appended.
    :return: ``Falsity()`` as soon as the empty clause is derived; otherwise the
        list of derived resolvents that are not already in ``kb``.
    """
    # Note:
    # kb is a CNF implemented as list<Disjunction>

    def resolvable(left, right):
        # Return the first literal of `left` whose negation occurs in `right`,
        # or None if the two clauses cannot be resolved.
        for atom in left.parts:
            if atom.negate() in right.parts:
                return atom
        return None

    def generator(previous_kb, resulting_kb, start_index):
        # Yield the resolvent of every ordered clause pair (i, j) with i >= start_index.
        # resulting_kb is the same list the consumer below appends to, so the loop
        # bounds are re-read and resolvents derived meanwhile are resolved as well.
        len_previous_kb = len(previous_kb)  # small optimization reduces calls to len by at least 4 per outer loop
        i = start_index
        while i < len_previous_kb + len(resulting_kb):
            j = 0
            while j < len_previous_kb + len(resulting_kb):
                # presumably indexes previous_kb + resulting_kb as one joined sequence
                l = get_from_joined_list(i, previous_kb, resulting_kb, len_previous_kb)
                r = get_from_joined_list(j, previous_kb, resulting_kb, len_previous_kb)
                left_atom = resolvable(l, r)
                if left_atom:
                    # Resolve on the first complementary pair only. Literals occurring in
                    # both parents are kept twice (no merging) and tautological
                    # resolvents are not filtered out.
                    left_parts = list(l.parts)
                    left_parts.remove(left_atom)
                    right_parts = list(r.parts)
                    right_parts.remove(left_atom.negate())
                    combined_parts = left_parts + right_parts
                    if not combined_parts:
                        # Empty clause: contradiction. The consumer returns on it, so
                        # the yield below is not reached in that case.
                        yield Falsity()
                    yield Disjunction(sorted(combined_parts))
                j += 1
            i += 1

    result = []
    for resolvent in generator(kb, result, start_index):
        if type(resolvent) is Falsity:
            return resolvent
        # keep only clauses not seen before (neither derived nor part of kb)
        if resolvent not in result and resolvent not in kb:
            result.append(resolvent)
    return result


def build_pointerlist(C):
    """Index the clauses of ``C`` by the complement of each of their literals.

    The result maps a literal ``l`` to the clauses containing ``not l``, i.e. the
    clauses that lose an open literal once ``l`` is forced true. A clause is
    listed once per matching occurrence. Because the result is a
    ``defaultdict(list)``, literals that occur nowhere map to an empty list.
    """
    pointers = defaultdict(list)
    pairs = ((literal.negate(), clause) for clause in C for literal in clause.parts)
    for complement, clause in pairs:
        pointers[complement].append(clause)
    return pointers


# unit_clauses: object whose parts are the literals assumed true (sat2 passes a Conjunction)
def unit_resolution(clauses, clauses_with_neg_occurance: dict, unit_clauses):
    """Unit-propagate the literals of ``unit_clauses`` through ``clauses``.

    :param clauses: the clause set; each clause has a ``parts`` sequence of literals.
    :param clauses_with_neg_occurance: mapping literal -> clauses of ``clauses``
        that contain the negation of that literal. It is indexed with every
        forced literal, so it must return an empty list for missing keys
        (e.g. the ``defaultdict(list)`` built by build_pointerlist).
    :param unit_clauses: object whose ``parts`` are the literals forced true.
    :return: ``None`` if propagation derives a conflict; otherwise the original
        clause objects that still have more than one non-falsified literal and
        are not satisfied by any forced literal.
    """
    literal_marked_forced = set()
    # Forced literals whose consequences have not been propagated yet. A literal is
    # pushed only when it is first marked, so every clause counter below is
    # decremented exactly once per falsified literal.
    forced_literals = []
    # Number of literals of each clause that have not been falsified yet.
    unmarked_literal_per_claus = {claus: len(claus.parts) for claus in clauses}

    for l in unit_clauses.parts:
        if l.negate() in literal_marked_forced:
            return None
        if l not in literal_marked_forced:
            literal_marked_forced.add(l)
            forced_literals.append(l)
    # Unit clauses of the clause set are forced as well.
    for disjunction in clauses:
        if len(disjunction.parts) == 1:
            l = disjunction.parts[0]
            if l.negate() in literal_marked_forced:
                return None
            if l not in literal_marked_forced:
                literal_marked_forced.add(l)
                forced_literals.append(l)

    while forced_literals:
        forced_literal = forced_literals.pop()
        for claus_to_reduce in clauses_with_neg_occurance[forced_literal]:
            assert claus_to_reduce in unmarked_literal_per_claus, "Each literal must have a corresponding counter of " \
                                                                  "unmarked literals, " + str(claus_to_reduce) + " did not"
            # forced_literal falsifies one literal of this clause
            unmarked_literal_per_claus[claus_to_reduce] -= 1
            if unmarked_literal_per_claus[claus_to_reduce] != 1:
                continue
            # All but one literal are falsified: the remaining literal, unless it was
            # falsified in the meantime, is forced true.
            found = 0  # literals of this clause that are not falsified
            for l in claus_to_reduce.parts:
                if l.negate() in literal_marked_forced:
                    continue
                found += 1
                if l not in literal_marked_forced:
                    forced_literals.append(l)
                    literal_marked_forced.add(l)
            assert found < 2, "appaerently it goes above 1"
            if found == 0:
                # every literal of the clause is falsified: clauses + unit_clauses is unsatisfiable
                return None

    result = []
    for claus, number_unmarked in unmarked_literal_per_claus.items():
        if number_unmarked > 1 and not any(l in literal_marked_forced for l in claus.parts):
            result.append(claus)
    return result


def sat2(C, pointer_list, c: Conjunction, literal_list):
    """Decide whether the clauses ``C`` together with the literals of ``c`` are satisfiable.

    Unit propagation (unit_resolution) runs first; the clauses it leaves open must
    be binary and are checked with the 2-SAT criterion: the formula is
    unsatisfiable iff some literal and its negation lie in the same strongly
    connected component of the implication graph over ``literal_list``.

    :param C: candidate clauses.
    :param pointer_list: mapping literal -> clauses containing its negation.
    :param c: conjunction of literals assumed true.
    :param literal_list: all literals, used as the nodes of the implication graph.
    :return: False if unsatisfiability was detected, True otherwise.
    """
    remaining = unit_resolution(C, pointer_list, c)
    if remaining is None:
        # unit propagation already ran into a conflict
        return False
    if not remaining:
        return True

    implication_edges = []
    for clause in remaining:
        assert len(clause.parts) == 2, "a disjunction had length equal " + str(
            len(clause.parts)) + " after unitresolution, should not happen"
        first, second = clause.parts[0], clause.parts[1]
        # (first or second) is equivalent to (not first -> second) and (not second -> first)
        implication_edges.append([first.negate(), second])
        implication_edges.append([second.negate(), first])

    # strongly connected components of the implication graph
    components = tarjan(literal_list, implication_edges)
    return not any(literal.negate() in component
                   for component in components
                   for literal in component)


# A: iterable of positive pddl.Atom, I: collection of the atoms that are true initially
def get_initial_candidates(A, I):
    """Build one unit-clause candidate per atom, using the literal that holds in ``I``.

    This is the paper's initial candidate set
        C = {a | a in A, I |= a}  union  {not a | a in A, I |= not a},
    where an atom missing from ``I`` counts as false.
    """
    candidates = []
    for atom in A:
        assert not atom.negated, "The list of all atoms should not contain any negated atom"
        literal = atom if atom in I else atom.negate()
        candidates.append(Disjunction([literal]))
    return candidates


# Number of atoms get_new_candidates skipped because of the smart-successor
# optimization (each skip avoids two candidates); logged by rin08_invariants in debug mode.
counter_smartsuccessors = 0


# A: iterable of pddl.Atom (rin08_invariants passes the affected atoms)
def get_new_candidates(A, C, c: Disjunction):
    """Weaken a refuted candidate ``c`` by one more literal.

    For every atom ``a`` in ``A`` this yields ``a or c`` and ``not a or c``
    (each with its literals sorted), in that order.

    Optimization (options.invariant_rin08_opt_smartsuccessors): skip ``a`` while
    it is still in ``C``, because ``a`` alone is the stronger invariant; should
    it be refuted later, ``a or c`` is generated at that point.

    :param A: atoms to extend ``c`` with.
    :param C: current candidate list (only consulted by the optimization).
    :param c: the refuted candidate; must be a Disjunction.
    :return: list of new Disjunction candidates.
    """
    # the skip counter lives at module level so rin08_invariants can report it
    global counter_smartsuccessors
    assert type(c) is Disjunction
    new_candidates = []
    for a in A:
        # NOTE(review): C holds Disjunctions while a is an atom; unless the pddl classes
        # compare them as equal this test never succeeds. Confirm whether
        # Disjunction([a]) in C was intended.
        if options.invariant_rin08_opt_smartsuccessors and a in C:
            counter_smartsuccessors += 1
            continue
        new_candidates.append(Disjunction(sorted([a] + list(c.parts))))
        new_candidates.append(Disjunction(sorted([a.negate()] + list(c.parts))))
    return new_candidates


def trivial(candidate):
    """Return True if ``candidate`` contains a literal together with its negation.

    Such a disjunction is a tautology, so it holds in every state and needs no
    invariance test. Any complementary pair is detected, regardless of its
    position or of the clause length; for clauses with at most two literals
    this is exactly "the two literals are each other's negation".
    """
    literals = set(candidate.parts)
    return any(literal.negate() in literals for literal in literals)


def get_affected_atoms(operators):
    """Collect the atoms whose truth value some operator can change.

    Every atom an operator deletes is included; an atom it adds is included
    only when that atom is not part of the operator's precondition.
    Conditions attached to the effects are ignored.
    """
    affected = set()
    for operator in operators:
        affected.update(atom for _condition, atom in operator.del_effects)
        affected.update(
            atom for _condition, atom in operator.add_effects
            if atom not in operator.precondition
        )
    return affected


# atoms: set of pddl.Atom (all atoms of the task)
# initial_state: list of pddl.Atom that are true initially; not every atom is listed, and
#     many listed atoms may be irrelevant here (e.g. ball(ball2))
# operators: list of pddl.PropositionalAction
# The maximal clause size n is read from options.rintanen_max_invarsize.
def rin08_invariants(atoms, initial_state, operators):  # G-IRIS
    """Compute clausal invariants with Rintanen's fixpoint algorithm.

    Every affected atom starts as a unit candidate with the polarity it has in
    ``initial_state``; unit clauses of atoms no operator affects are put on a
    separate trivial list and never tested. Each pass snapshots the candidates
    (``Ctemp``) and, for every candidate ``c`` and operator ``o``, tests whether
    the regression of ``not c`` through ``o`` is satisfiable together with the
    snapshot. If it is, ``c`` is dropped and, while ``c`` has fewer than n
    literals, the candidates from get_new_candidates are added (those for which
    trivial() holds go to the trivial list). Passes repeat until one changes
    nothing.

    :return: the surviving candidates followed by the trivial ones (a list of
        Disjunctions), or None when options.invariant_generation_max_time is
        exceeded.

    Note: remove_trivial_operators(operators) is called first and its return
    value is not used, so it presumably filters ``operators`` in place - verify.
    With options.invariant_debug, per-pass snapshots are written to a new
    ``output/<timestamp>`` directory.
    """
    n = options.rintanen_max_invarsize
    assert n >= 1, "Rintanens Invariant Algorithm cannot produce invariants of size less than one"

    remove_trivial_operators(operators)
    # atoms that some operator deletes, or adds without requiring them in its precondition
    affected_atoms = get_affected_atoms(operators)
    if options.invariant_debug:
        opt_details = "optimization: only adding new formulas=" + str(
            options.invariant_rin08_opt_smartsuccessors) + "  directly test new candidates with same operator=" + str(
            options.invariant_rin08_opt_sameoperatortesting)
        start_ts = str(datetime.datetime.now()).replace(":", "").replace(".", "").replace(" ", "_")
        path = "output/" + start_ts
        # NOTE(review): os.mkdir does not create parent directories, so "output/" must already exist.
        mkdir(path)
    counter_iteration = 0
    if options.invariant_rin08_opt_sameoperatortesting:
        counter_sameoperatortesting = 0
    # Note: C is a list of Disjunctions; each Disjunction is one candidate clause and its
    # parts are the (disjunctive) literals of that clause.
    C = get_initial_candidates(affected_atoms, initial_state)
    # Atoms no operator affects keep their initial value, so their unit clauses are kept untested.
    C_trivial = get_initial_candidates(set(atoms).difference(affected_atoms), initial_state)
    hashes_considered_canidates = set()
    hashes_added_canidates = set()
    candidate_tests = 0
    list_changed = True
    if SAT2:
        # nodes of the implication graph used by sat2: every atom in both polarities
        literal_list = [a for a in atoms] + [a.negate() for a in atoms]
    else:
        literal_list = None
    while list_changed:
        list_changed = False
        # Snapshot of the candidates at the start of this pass; every test in the pass is
        # made against it while C itself is edited below.
        Ctemp = copy.deepcopy(C)
        if options.invariant_debug:
            with open(path + "/iteration" + str(counter_iteration) + ".txt", "w") as file:
                file.write(
                    "Ctemp in run " + str(counter_iteration) + ":" +
                    str([c.parts for c in Ctemp]).replace("],", "], \n") + "with " + opt_details)
        if not SAT2:
            # Resolution path: kb = resolvents of the snapshot plus the snapshot itself.
            # NOTE(review): rin08_resolution returns Falsity() if Ctemp is contradictory,
            # in which case kb.extend below would fail.
            if options.invariant_debug:
                with timers.timing("build kb"):
                    kb = rin08_resolution(Ctemp)
                print("size full kb:" + str(len(kb)) + "    and size C:" + str(len(C)) + "   with " + opt_details)
            else:
                kb = rin08_resolution(Ctemp)
            kb.extend(Ctemp)
        else:
            # SAT2 path: kb maps each literal to the clauses containing its negation. It is
            # built from C, which at this point has the same contents as Ctemp.
            if options.invariant_debug:
                with timers.timing("did one iteration"):
                    kb = build_pointerlist(C)
            else:
                kb = build_pointerlist(C)
                # kb2 = rin08_resolution(Ctemp)
                # kb2.extend(Ctemp)
        i = 0
        while i < len(C):
            if time.process_time() - start_time > options.invariant_generation_max_time:
                print("Time limit reached, aborting invariant generation")
                # NOTE(review): this call passes three arguments (elapsed time first) while the
                # call after the loop passes four different ones; confirm the signature of
                # invar_search_output. The bare return yields None, which rin08_get_groups
                # hands on to invariants_to_types.
                invar_search_output(time.process_time() - start_time, counter_iteration, -1)
                return
            c = C[i]
            hashes_considered_canidates.add(c.hash)
            for o in operators:
                # Regression of "not c" through o: if it is satisfiable with the snapshot,
                # c is not proven to be preserved by o and is dropped.
                rnn = regress_strips(o, c.negate())
                candidate_tests += 1
                if satisfies(rnn, kb, Ctemp, literal_list):
                    # if not satisfies(rnn, kb2, change=True):
                    #    print("disagreement new sais it is")
                    #    a = satisfies(rnn, kb, Ctemp, literal_list)
                    #    b = satisfies(rnn, kb2, change=True)
                    C.pop(i)
                    list_changed = True
                    if len(c.parts) < n:  # c has less literals than n
                        for candidate in get_new_candidates(affected_atoms, C, c):
                            if trivial(candidate):
                                C_trivial.append(candidate)
                                continue
                            if options.invariant_rin08_opt_sameoperatortesting:
                                # Only add candidates that were never added before and that
                                # the same operator o does not already refute.
                                candidate_tests += 1
                                if candidate.hash not in hashes_added_canidates and \
                                        not satisfies(regress_strips(o, candidate.negate()), kb, Ctemp, literal_list):
                                    hashes_added_canidates.add(candidate.hash)
                                    C.append(candidate)
                                else:
                                    counter_sameoperatortesting += 1
                            else:
                                C.append(candidate)
                    # Note: i is not increased because we just caused a shift
                    break
            else:
                # no operator refuted c: keep it and move to the next candidate
                i += 1
        counter_iteration += 1

    def list_to_string(list_of_disjunctions):
        # One line per disjunction, each literal followed by a comma.
        result = ""
        if list_of_disjunctions:
            for disjunction in list_of_disjunctions:
                for part in disjunction.parts:
                    result += str(part) + ","
                result += "\n"
        return result

    C = C + C_trivial
    invar_search_output(counter_iteration, len(C), len(hashes_considered_canidates),
                        candidate_tests)

    if options.invariant_debug:
        if options.invariant_rin08_opt_sameoperatortesting:
            translator_invariant_log("outer loops saved by same operator optimization", counter_sameoperatortesting)
        if options.invariant_rin08_opt_smartsuccessors:
            # Each skip in get_new_candidates avoids two candidates. NOTE(review): relies on
            # get_new_candidates updating the module-level counter_smartsuccessors.
            translator_invariant_log("additions saved by optimization smart successor", (counter_smartsuccessors * 2))
        with open(path + "/final.txt", "w") as file:
            file.write("Ctemp in last run:" + list_to_string(C))
    return C


def rin08_get_groups(atoms, initial_state, operators):
    """Group ``atoms`` into cliques of the mutex graph derived from rin08 invariants.

    Only the invariants classified as NAVNB by invariants_to_types (presumably
    binary "not a or not b" clauses - confirm) are turned into mutex edges; the
    atoms are then partitioned with partition_to_cliques on a copy of ``atoms``.
    NOTE(review): rin08_invariants returns None on timeout, which is passed on as is.
    """
    with timers.timing("Finding invariants", block=True):
        invariants = rin08_invariants(atoms, initial_state, operators)

    mutex_invariants = invariants_to_types(invariants)[NAVNB]
    mutex_edges = get_mutex_edges(mutex_invariants)
    log_number_mutexes(mutex_edges)

    return partition_to_cliques(nodes=atoms.copy(), edges=mutex_edges)
