#! /usr/bin/env python3


import sys
import itertools
import copy

import normalize
import pddl
import timers

class PrologProgram:
    """A Datalog program (called a "Prolog" program throughout this module).

    Holds the facts (as Fact objects), the rules, the set of objects that
    occur as arguments of facts added through add_fact(), and a generator of
    fresh pseudo-predicate names "p$0", "p$1", ... used when rules are split.
    """

    def __init__(self):
        self.facts = []
        self.rules = []
        # Every argument of every atom passed to add_fact().
        self.objects = set()
        def predicate_name_generator():
            for count in itertools.count():
                yield "p$%d" % count
        self.new_name = predicate_name_generator()

    def add_fact(self, atom):
        """Add *atom* as a fact and record its arguments as objects."""
        self.facts.append(Fact(atom))
        self.objects |= set(atom.args)

    def add_rule(self, rule):
        self.rules.append(rule)

    def dump(self, file=None):
        """Print all facts, then all rules prefixed by their ``type``
        attribute ("none" when a rule has none), to *file* (stdout if None)."""
        for fact in self.facts:
            print(fact, file=file)
        for rule in self.rules:
            print(getattr(rule, "type", "none"), rule, file=file)

    def normalize(self):
        # Normalized prolog programs have the following properties:
        # 1. Each variable that occurs in the effect of a rule also occurs in its
        #    condition.
        # 2. The variables that appear in each effect or condition are distinct.
        # Rules with an empty condition are deliberately KEPT (the optimized
        # hadd needs them for operators), so convert_trivial_rules() is not
        # called here. Instead, the effect of each such rule is additionally
        # registered as a fact. After remove_free_effect_variables() such an
        # effect has no variables, since any effect variable would have
        # produced an @object condition.
        self.remove_free_effect_variables()
        self.split_duplicate_arguments()
        for rule in self.rules:
            if not rule.conditions:
                self.add_fact(rule.effect)

    def split_rules(self):
        import split_rules
        # Splits rules whose conditions can be partitioned in such a way that
        # the parts have disjoint variable sets, then split n-ary joins into
        # a number of binary joins, introducing new pseudo-predicates for the
        # intermediate values.
        new_rules = []
        for rule in self.rules:
            new_rules += split_rules.split_rule(rule, self.new_name)
        self.rules = new_rules

    def remove_free_effect_variables(self):
        """Remove free effect variables like the variable Y in the rule
        p(X, Y) :- q(X). This is done by introducing a new predicate
        @object, setting it true for all objects, and translating the above
        rule to p(X, Y) :- q(X), @object(Y).
        After calling this, no new objects should be introduced!"""

        # Note: This should never be necessary for typed domains.
        # Leaving it in at the moment regardless.
        must_add_predicate = False
        for rule in self.rules:
            eff_vars = get_variables([rule.effect])
            cond_vars = get_variables(rule.conditions)
            if not eff_vars.issubset(cond_vars):
                must_add_predicate = True
                eff_vars -= cond_vars
                # Sorted so that the added conditions are deterministic.
                for var in sorted(eff_vars):
                    rule.add_condition(pddl.Atom("@object", [var]))
        if must_add_predicate:
            print("Unbound effect variables: Adding @object predicate.")
            self.facts += [Fact(pddl.Atom("@object", [obj])) for obj in self.objects]

    def split_duplicate_arguments(self):
        """Make sure that no variable occurs twice within the same symbolic fact,
        like the variable X does in p(X, Y, X). This is done by renaming the second
        and following occurrences of the variable and adding equality conditions.
        For example p(X, Y, X) is translated to p(X, Y, X@0) with the additional
        condition =(X, X@0); the equality predicate must be appropriately instantiated
        somewhere else."""
        printed_message = False
        for rule in self.rules:
            if rule.rename_duplicate_variables() and not printed_message:
                print("Duplicate arguments: Adding equality conditions.")
                printed_message = True

    def convert_trivial_rules(self):
        """Convert rules with an empty condition into facts.
        This must be called after bounding rule effects, so that rules with an
        empty condition must necessarily have a variable-free effect.
        Variable-free effects are the only ones for which a distinction between
        ground and symbolic atoms is not necessary.

        Not called by normalize() in this version (see the comment there)."""
        must_delete_rules = []
        for i, rule in enumerate(self.rules):
            if not rule.conditions:
                assert not get_variables([rule.effect])
                self.add_fact(pddl.Atom(rule.effect.predicate, rule.effect.args))
                must_delete_rules.append(i)
        if must_delete_rules:
            print("Trivial rules: Converted to facts.")
            # Delete from the back so earlier indices stay valid.
            for rule_no in must_delete_rules[::-1]:
                del self.rules[rule_no]

    def remove_action_predicates(self, task):
        '''
        Remove the action predicates and restructure the Datalog program.

        Rules whose effect is an action atom ("action rules") are dropped.
        Every other rule whose single condition equals, by string form, the
        effect of an action rule is replaced by a copy of that action rule
        with the other rule's effect. Action rules take their weight from the
        action's cost as computed by instantiate_costs(task). For example,

        join action_a(?x, ?b) :- p(?x, ?b), r(?x).
        project eff_1(?x) :- action_a(?x, ?b).
        project eff_2(?b) :- action_a(?x, ?b).

        becomes

        join eff_1(?x) :- p(?x, ?b), r(?x).
        join eff_2(?b) :- p(?x, ?b), r(?x).

        This *needs* to be made before the renaming (rename_free_variables).
        '''
        costs = instantiate_costs(task)
        action_rules = {}
        non_action_rules = []
        for rule in self.rules:
            rule_name = str(rule.effect)
            if not rule_name.startswith("Atom <Action"):
                non_action_rules.append(rule)
                continue
            # The action's cost is matched through its id, which is expected
            # to appear in hex inside the action's string form.
            # NOTE(review): this relies on the pddl Action repr; if several
            # ids match, the last one wins.
            for action_id, cost in costs:
                if hex(action_id) in rule_name:
                    rule.weight = cost
            action_rules[rule_name] = rule

        final_rules = []
        for rule in non_action_rules:
            action_rule = None
            if len(rule.conditions) == 1:
                action_rule = action_rules.get(str(rule.conditions[0]))
            if action_rule is None:
                final_rules.append(rule)
            else:
                new_action_rule = copy.deepcopy(action_rule)
                new_action_rule.effect = rule.effect
                final_rules.append(new_action_rule)
        self.rules = final_rules

    def rename_free_variables(self):
        '''
        Use canonical names for free variables. Each rule is renamed
        independently. The names ?var0, ?var1, ... are assigned in the order in
        which the variables first show up in the rule, starting with the effect
        and then the conditions from left to right. They do not follow the
        names used in the PDDL file.
        Arguments not starting with "?" are left untouched. The rules are
        deep-copied first, so the previous rule objects are not modified.
        '''
        def canonical(arg, renaming):
            if arg[0] != "?":
                return arg
            if arg not in renaming:
                renaming[arg] = "?var%d" % len(renaming)
            return renaming[arg]

        new_rules = []
        for old_rule in self.rules:
            rule = copy.deepcopy(old_rule)
            renaming = {}
            # The effect must be renamed before the conditions so that the
            # numbering starts with the effect's variables.
            rule.effect = self._rebuild_atom(
                rule.effect, rule.effect.predicate,
                tuple(canonical(arg, renaming) for arg in rule.effect.args))
            rule.conditions = [
                self._rebuild_atom(
                    condition, condition.predicate,
                    tuple(canonical(arg, renaming) for arg in condition.args))
                for condition in rule.conditions]
            new_rules.append(rule)
        self.rules = new_rules

    def find_equivalent_rules(self, rules):
        '''
        Find auxiliary rules that duplicate an earlier auxiliary rule.

        A rule counts as auxiliary when "p$" occurs in the string form of its
        effect. Two auxiliary rules are duplicates when both the string form
        of their conditions and the string form of their effect arguments are
        equal. Only the first rule of each such group is kept.
        NOTE(review): rule weight and type are not part of the comparison.

        Returns (has_duplication, kept_rules, equivalence). In the result,
        equivalence maps str(predicate) of every dropped rule's effect to the
        effect predicate of the kept rule it duplicates.
        '''
        has_duplication = False
        kept_rules = []
        first_predicate_for_key = {}
        equivalence = {}
        for rule in rules:
            if "p$" in str(rule.effect):
                key = (str(rule.conditions), str(rule.effect.args))
                if key in first_predicate_for_key:
                    equivalence[str(rule.effect.predicate)] = first_predicate_for_key[key]
                    has_duplication = True
                    continue
                first_predicate_for_key[key] = rule.effect.predicate
            kept_rules.append(rule)
        return has_duplication, kept_rules, equivalence

    def remove_duplicated_rules(self):
        '''
        Remove redundant and duplicated rules from the IDB of the Datalog.

        Duplicated auxiliary rules (see find_equivalent_rules) are dropped.
        Every condition that uses a dropped predicate is redirected to the
        kept one. Redirection can make further rules identical, so this
        repeats until no duplicate is found. The total number of removed
        rules is printed to stderr.
        '''
        has_duplication = True
        total_rules_removed = 0
        while has_duplication:
            has_duplication, kept_rules, equivalence = self.find_equivalent_rules(self.rules)
            # Count removed rules (not replaced conditions), as the message says.
            total_rules_removed += len(self.rules) - len(kept_rules)
            if equivalence:
                for rule in kept_rules:
                    rule.conditions = [
                        self._rebuild_atom(condition,
                                           equivalence[str(condition.predicate)],
                                           condition.args)
                        if str(condition.predicate) in equivalence else condition
                        for condition in rule.conditions]
            self.rules = kept_rules
        print("Total number of duplicated rules removed: %d" % total_rules_removed, file=sys.stderr)

    @staticmethod
    def _rebuild_atom(atom, predicate, args):
        """Return a new atom of the same class as *atom* with the given
        predicate and arguments.

        Atoms are rebuilt rather than mutated in place, so that any data the
        atom class derives from predicate/args at construction time stays
        consistent (NOTE(review): e.g. a cached hash in pddl literals).
        """
        return type(atom)(predicate, args)

def get_variables(symbolic_atoms):
    """Return the set of variables (arguments starting with "?") that occur
    in any of the given symbolic atoms."""
    return {argument
            for atom in symbolic_atoms
            for argument in atom.args
            if argument[0] == "?"}

def instantiate_costs(task):
    """Return one [id(action), cost] pair per action of *task*, in order.

    Under a min-cost metric the cost is the action's cost expression
    evaluated against the initial numeric assignments and truncated to int.
    An action without a cost expression costs 0. Without that metric, every
    action costs 1.
    """
    init_assignments = {
        element.fluent: element.expression
        for element in task.init
        if isinstance(element, pddl.Assign)}

    def cost_of(action):
        if not task.use_min_cost_metric:
            return 1
        if action.cost is None:
            return 0
        instantiated = action.cost.instantiate(None, init_assignments)
        return int(instantiated.expression.value)

    return [[id(action), cost_of(action)] for action in task.actions]

def remove_duplicate_preconditions_in_actions(task):
    """Drop repeated conjuncts from every conjunctive action precondition.

    A conjunct is kept only if it is not equal to any conjunct before it, so
    the first occurrence wins and the order is preserved. Preconditions that
    are not pddl.Conjunction instances are left alone. Modifies *task* in
    place.
    """
    for action in task.actions:
        precondition = action.precondition
        if not isinstance(precondition, pddl.Conjunction):
            continue
        original_parts = precondition.parts
        unique_parts = []
        for position, part in enumerate(original_parts):
            if part not in original_parts[:position]:
                unique_parts.append(part)
        precondition.parts = tuple(unique_parts)

class Fact:
    """An atom stated unconditionally in the Datalog program.

    Printed in Datalog syntax as the atom followed by a period.
    """

    def __init__(self, atom):
        self.atom = atom

    def __str__(self):
        fact_template = "%s."
        return fact_template % self.atom

class Rule:
    """A Datalog rule ``effect :- conditions`` carrying a numeric weight.

    *conditions* is a list of symbolic atoms; it is stored as given (not
    copied) and extended in place by add_condition().
    """

    def __init__(self, conditions, effect, weight=0):
        self.conditions = conditions
        self.effect = effect
        self.weight = weight

    def add_condition(self, condition):
        self.conditions.append(condition)

    def get_variables(self):
        # Delegates to the module-level get_variables() helper.
        return get_variables(self.conditions + [self.effect])

    def _rename_duplicate_variables(self, atom, new_conditions):
        """Rename repeated variables of *atom* to "<var>@<n>" and record an
        equality condition for each rename in *new_conditions*. The number n
        is the length of *new_conditions* at the time of the rename. Returns
        the (possibly replaced) atom."""
        seen_variables = set()
        for position, argument in enumerate(atom.args):
            if argument[0] != "?":
                continue
            if argument not in seen_variables:
                seen_variables.add(argument)
                continue
            fresh_name = "%s@%d" % (argument, len(new_conditions))
            atom = atom.replace_argument(position, fresh_name)
            new_conditions.append(pddl.Atom("=", [argument, fresh_name]))
        return atom

    def rename_duplicate_variables(self):
        """Make variables within each atom of the rule distinct. The effect is
        processed first, then the conditions in order. Returns True iff any
        equality condition had to be added."""
        equalities = []
        self.effect = self._rename_duplicate_variables(self.effect, equalities)
        renamed_conditions = [
            self._rename_duplicate_variables(condition, equalities)
            for condition in self.conditions]
        self.conditions = renamed_conditions + equalities
        return bool(equalities)

    def str_weighted(self):
        body = ", ".join(str(condition) for condition in self.conditions)
        return "%s :- %s [%s]." % (self.effect, body, self.weight)

    def __str__(self):
        body = ", ".join(str(condition) for condition in self.conditions)
        return "%s :- %s." % (self.effect, body)

def translate_typed_object(prog, obj, type_dict):
    """Add to *prog* one type fact for *obj*'s own type and one for each of
    that type's supertypes, as listed in *type_dict* (type name -> type)."""
    own_type_name = obj.type_name
    all_type_names = [own_type_name] + type_dict[own_type_name].supertype_names
    for type_name in all_type_names:
        typed_object = pddl.TypedObject(obj.name, type_name)
        prog.add_fact(typed_object.get_atom())

def translate_facts(prog, task):
    """Add the facts of *task* to *prog*: first the type facts of every
    object, then the initial-state atoms. Each numeric assignment in the
    initial state is added as a definition fact for its fluent."""
    types_by_name = {type_.name: type_ for type_ in task.types}
    for obj in task.objects:
        translate_typed_object(prog, obj, types_by_name)
    for init_element in task.init:
        assert isinstance(init_element, (pddl.Atom, pddl.Assign))
        if not isinstance(init_element, pddl.Atom):
            # Add a fact to indicate that the primitive numeric expression in
            # init_element.fluent has been defined.
            prog.add_fact(normalize.get_pne_definition_predicate(init_element.fluent))
        else:
            prog.add_fact(init_element)

def translate(task):
    """Build the Datalog program for *task*, then normalize and split it.

    The task must already have been normalized. Both phases are timed with
    timers.timing().
    """
    with timers.timing("Generating Datalog program"):
        prog = PrologProgram()
        translate_facts(prog, task)
        exploration_rules = normalize.build_exploration_rules(task)
        for rule_conditions, rule_effect in exploration_rules:
            prog.add_rule(Rule(rule_conditions, rule_effect))
    # block=True because normalization can print messages in rare cases.
    with timers.timing("Normalizing Datalog program", block=True):
        prog.normalize()
        prog.split_rules()
    return prog

def translate_optimize(task):
    """Build the optimized Datalog program for *task*.

    Unlike translate(), this first removes duplicate preconditions from the
    task's actions (modifying *task*). It folds the action predicates into
    the rules that use them, then normalizes and splits the program,
    canonicalizes variable names and merges duplicated auxiliary rules. No
    timers are used.
    """
    remove_duplicate_preconditions_in_actions(task)
    program = PrologProgram()
    translate_facts(program, task)
    for rule_conditions, rule_effect in normalize.build_exploration_rules(task):
        program.add_rule(Rule(rule_conditions, rule_effect))
    # Action predicates must be removed before the variables are renamed.
    program.remove_action_predicates(task)
    post_processing_steps = (
        program.normalize,
        program.split_rules,
        program.rename_free_variables,
        program.remove_duplicated_rules,
    )
    for step in post_processing_steps:
        step()
    return program

if __name__ == "__main__":
    import pddl_parser
    task = pddl_parser.open()
    normalize.normalize(task)
    # translate() runs first: translate_optimize() modifies the task by
    # removing duplicate action preconditions.
    prog = translate(task)
    prog2 = translate_optimize(task)
    prog.dump()
    with open("translate_optimize.txt", "w") as output_file:
        for weighted_rule in prog2.rules:
            rule_type = getattr(weighted_rule, "type", "none")
            print(rule_type, weighted_rule.str_weighted(), file=output_file)
    for heading, program in (("Number of rules (original):", prog),
                             ("Number of rules (new):", prog2)):
        print(heading)
        print(len(program.rules))
