#include "pdb_optimal_constraints.h"

#include "lp_constraint_collection.h"
#include "operator_count_lp.h"

#include "../plugin.h"

#include "../pdbs/pattern_generation_haslum.h"
#include "../pdbs/pattern_generation_systematic.h"
#include "../pdbs/canonical_pdbs_heuristic.h"
#include "../pdbs/pdb_heuristic.h"

using namespace std;

namespace pho {
// TODO avoid code duplication in constructor with PDBConstraints
/*
  Chooses the pattern collection that initialize_constraints() later turns
  into PDBs. The first applicable source wins:
    1. an explicit "patterns" entry in opts;
    2. otherwise, if "systematic" is present and non-zero, systematic
       generation of all patterns with at most that many variables -- with
       irrelevant-pattern pruning (PatternGenerationSystematic) when
       "prune_irrelevant_patterns" is set, without it
       (PatternGenerationSystematicNaive) otherwise;
    3. whenever the steps above leave the collection empty, the patterns of
       the collection found by iPDB hill climbing (PatternGenerationHaslum).
  Only the patterns are kept; the PDBs themselves are rebuilt later.
*/
PDBOptimalConstraints::PDBOptimalConstraints(const Options &opts)
    : cost_type(OperatorCost(opts.get_enum("cost_type"))) {
    if (opts.contains("patterns")) {
        patterns = opts.get<vector<vector<int> > >("patterns");
    } else if (opts.contains("systematic") && opts.get<int>("systematic") != 0) {
        Options systematic_opts;
        systematic_opts.set<int>("pattern_max_size", opts.get<int>("systematic"));
        systematic_opts.set<bool>("dominance_pruning", opts.get<bool>("dominance_pruning"));
        const bool prune_irrelevant =
            opts.contains("prune_irrelevant_patterns") &&
            opts.get<bool>("prune_irrelevant_patterns");
        if (!prune_irrelevant) {
            PatternGenerationSystematicNaive naive_generator(systematic_opts);
            patterns = naive_generator.get_patterns();
        } else {
            PatternGenerationSystematic pruning_generator(systematic_opts);
            patterns = pruning_generator.get_patterns();
        }
    }

    if (!patterns.empty())
        return;

    // Fall back to iPDB. Only the patterns are copied out; the PDB objects
    // built by iPDB are not reused, because the PDBs used here must store
    // their abstract transition system while they are being constructed.
    PatternGenerationHaslum ipdb(opts);
    CanonicalPDBsHeuristic *collection = ipdb.get_pattern_collection_heuristic();
    const vector<PDBHeuristic *> &pdbs = collection->get_pattern_databases();
    for (size_t db = 0; db != pdbs.size(); ++db)
        patterns.push_back(pdbs[db]->get_pattern());
    // This object takes care of freeing the collection returned by iPDB.
    delete collection;
}


// Frees every PDB heuristic allocated in add_pattern(); this object owns them.
PDBOptimalConstraints::~PDBOptimalConstraints() {
    const size_t num_pdbs = heuristics.size();
    for (size_t pdb_index = 0; pdb_index != num_pdbs; ++pdb_index)
        delete heuristics[pdb_index];
}

/*
  Builds one PDB per pattern and adds its LP variables and constraints to
  constraint_collection (see add_pattern()). Progress is printed before
  every 1000th pattern and once more at the end.
  ignore_list is not read or modified by this generator.
*/
void PDBOptimalConstraints::initialize_constraints(LPConstraintCollection &constraint_collection, vector<int> &ignore_list) {
    const size_t num_patterns = patterns.size();
    for (size_t pattern_no = 0; pattern_no < num_patterns; ++pattern_no) {
        const bool report_progress = (pattern_no % 1000 == 0);
        if (report_progress)
            cout << "Generated " << pattern_no << "/" << num_patterns << " PDBs" << endl;
        add_pattern(patterns[pattern_no], constraint_collection);
    }
    cout << "Generated " << heuristics.size() << "/" << num_patterns << " PDBs" << endl;
}

void PDBOptimalConstraints::add_pattern(const vector<int> &pattern,
                                        LPConstraintCollection &constraint_collection) {
    Options pdb_opts;
    pdb_opts.set<int>("cost_type", cost_type);
    pdb_opts.set<vector<int> >("pattern", pattern);
    PDBHeuristic *h = new PDBHeuristic(pdb_opts, false, vector<int>(), true);

    // Add variables selected_transition[p][t] for each pattern p and
    // each abstract transition t = <s'', o, s'> of PDB p
    const vector<AbstractPDBTransition> *transitions = h->get_abstract_transitions();
    int transition_variable_offset = constraint_collection.add_variables(
        transitions->size(), LPVariable(0, numeric_limits<double>::infinity(), 0));

    // Add constraints operator_count[o] >= sum_{t in T} selected_transition[p][t]
    // where T = transitions of p labeled with o
    vector<LPConstraint> static_constraints(g_operators.size());
    for (size_t op_id = 0; op_id < g_operators.size(); ++op_id) {
        static_constraints[op_id].insert(op_id, 1.0);
        static_constraints[op_id].set_lower_bound(0);
    }
    for (size_t t_id = 0; t_id < transitions->size(); ++t_id) {
        AbstractPDBTransition t = (*transitions)[t_id];
        int transition_variable = transition_variable_offset + t_id;
        static_constraints[t.op_id].insert(transition_variable, -1.0);
    }

    // Add variables selected_goal[p][s*] for each pattern p and
    // each abstract goal state s* of PDB p
    const vector<int> *goals = h->get_abstract_goal_states();
    int goal_variable_offset = constraint_collection.add_variables(
        goals->size(), LPVariable(0, numeric_limits<double>::infinity(), 0));

    LPConstraint at_least_one_goal_constraint;
    at_least_one_goal_constraint.set_lower_bound(1.0);
    for (size_t g_id = 0; g_id < goals->size(); ++g_id) {
        at_least_one_goal_constraint.insert(goal_variable_offset + g_id, 1.0);
    }
    static_constraints.push_back(at_least_one_goal_constraint);

    // Add constraints
    //            D >= X                        if s' != s*
    //            D - selected_goal[p][s*] >= X if s' == s*
    // with D = sum_{t in IN(s')} selected_transition[p][t] -
    //          sum_{t in OUT(s')} selected_transition[p][t]
    vector<LPConstraint> entry_count_constraints(h->get_size());
    for (size_t t_id = 0; t_id < transitions->size(); ++t_id) {
        AbstractPDBTransition t = (*transitions)[t_id];
        int transition_variable = transition_variable_offset + t_id;
        // + 1 * selected_transition[p][t] for incoming transitions
        entry_count_constraints[t.to_state_index].insert(transition_variable, 1.0);
        // - 1 * selected_transition[p][t] for outgoing transitions
        entry_count_constraints[t.from_state_index].insert(transition_variable, -1.0);
    }
    for (size_t g_id = 0; g_id < goals->size(); ++g_id) {
        // - 1 * selected_goal[p][s*] for goal states
        int s_id = (*goals)[g_id];
        int g_variable = goal_variable_offset + g_id;
        entry_count_constraints[s_id].insert(g_variable, -1.0);
    }
    for (size_t i = 0; i < entry_count_constraints.size(); ++i) {
        entry_count_constraints[i].set_lower_bound(0);
    }

    // Add constraints to LP
    constraint_collection.add_constraints(static_constraints);
    int offset = constraint_collection.add_constraints(entry_count_constraints);
    int initial_state_index = h->hash_index(g_initial_state());

    heuristics.push_back(h);
    entry_count_contraint_offsets.push_back(offset);
    current_abstract_state_constraint_ids.push_back(offset + initial_state_index);

    h->clear_transition_system();
}

bool PDBOptimalConstraints::reach_state(const State &parent_state, const Operator &op,
                                        const State &state) {
    bool h_dirty = false;
    for (size_t i = 0; i < heuristics.size(); ++i) {
        if (heuristics[i]->reach_state(parent_state, op, state)) {
            h_dirty = true;
        }
    }
    return h_dirty;
}

bool PDBOptimalConstraints::update_constraints(const State &state, OperatorCountLP &lp) {
    for (size_t i = 0; i < heuristics.size(); ++i) {
        PDBHeuristic *h = heuristics[i];
        h->evaluate(state);
        if (h->is_dead_end()) {
            return true;
        }
        int old_state_constraint_id = current_abstract_state_constraint_ids[i];
        lp.set_permanent_constraint_lower_bound(
            old_state_constraint_id, 0);
        int new_state_constraint_id = entry_count_contraint_offsets[i] + h->hash_index(state);
        lp.set_permanent_constraint_lower_bound(
            new_state_constraint_id, -numeric_limits<double>::infinity());
        current_abstract_state_constraint_ids[i] = new_state_constraint_id;
    }
    return false;
}

/*
  Option parser for the "pdb_optimal_constraints" constraint generator.
  Registers all PatternGenerationHaslum (iPDB) options, "systematic" (the
  maximum pattern size for systematic generation; 0 means use iPDB),
  "prune_irrelevant_patterns", and the generic Heuristic options (presumably
  the source of the "cost_type" option read by the constructor).
  Returns 0 in help mode and in dry-run mode. Otherwise it returns a newly
  allocated PDBOptimalConstraints.
  NOTE(review): no "patterns" option is registered here, so the constructor's
  explicit-patterns branch looks unreachable through this parser -- confirm.
*/
static ConstraintGenerator *_parse(OptionParser &parser) {
    PatternGenerationHaslum::create_options(parser);
    parser.add_option<int>("systematic",
                           "systematically generate all patterns with up to n variables instead of using PatternGenerationHaslum.",
                           "0");
    parser.add_option<bool>("prune_irrelevant_patterns",
                            "prune irrelevant patterns before building the LP.",
                            "true");

    Heuristic::add_options_to_parser(parser);
    Options opts = parser.parse();
    if (parser.help_mode())
        return 0;
    PatternGenerationHaslum::sanity_check_options(parser, opts);
    // Any non-zero value selects systematic generation, so a negative size
    // would otherwise be passed on silently as a maximum pattern size.
    if (opts.get<int>("systematic") < 0)
        parser.error("systematic must be non-negative");

    if (parser.dry_run())
        return 0;
    return new PDBOptimalConstraints(opts);
}

// Static plugin object that makes _parse available under the name
// "pdb_optimal_constraints" as a ConstraintGenerator (presumably registered
// when this object is constructed during static initialization).
static Plugin<ConstraintGenerator> _plugin("pdb_optimal_constraints", _parse);
}
