#include "lm_count_constraints.h"

#include "lp_constraint_collection.h"
#include "operator_count_lp.h"

#include "../plugin.h"

#include <vector>

namespace pho {
/*
  Builds the generator from parsed options. lm_graph is initialized from the
  LandmarkGraph pointed to by the "lm_graph" option, the status manager is
  constructed on that graph, and landmark statuses are then set up for the
  initial state.

  NOTE(review): lm_status_manager is initialized from lm_graph, so lm_graph
  must be declared before lm_status_manager in the class definition (member
  initialization follows declaration order) — confirm in the header.
*/
LMCountConstraints::LMCountConstraints(const Options &opts)
    : lm_graph(*opts.get<LandmarkGraph *>("lm_graph")),
      lm_status_manager(lm_graph) {
    lm_status_manager.set_landmarks_for_initial_state();
}

// Nothing is released explicitly here; all members are destroyed
// automatically when the generator goes away.
LMCountConstraints::~LMCountConstraints() {}

/*
  Picks the achiever set that matters for a landmark in its current status:
    - lm_not_reached  -> the landmark's first achievers,
    - lm_needed_again -> the landmark's possible achievers,
    - any other status -> the `empty` member set (presumably an always-empty
      set kept so a reference can be returned; confirm in the header).
  The result is a reference into `lmn` or into this object, not a copy.
*/
const set<int> &LMCountConstraints::get_achievers(
    int lmn_status, const LandmarkNode &lmn) const {
    if (lmn_status == lm_not_reached) {
        return lmn.first_achievers;
    }
    if (lmn_status == lm_needed_again) {
        return lmn.possible_achievers;
    }
    return empty;
}

/*
  Notifies the landmark status manager of the transition
  parent_state --op--> state via update_reached_lms().
  Always returns true.
*/
bool LMCountConstraints::reach_state(
    const State &parent_state, const Operator &op, const State &state) {
    lm_status_manager.update_reached_lms(parent_state, op, state);
    return true;
}

bool LMCountConstraints::update_constraints(const State &state, OperatorCountLP &lp) {
    // Need explicit test to see if state is a goal state. The landmark
    // heuristic may compute h != 0 for a goal state if landmarks are
    // achieved before their parents in the landmarks graph (because
    // they do not get counted as reached in that case). However, we
    // must return 0 for a goal state.
    bool goal_reached = test_goal(state);
    if (goal_reached) {
        return false;
    }
    bool dead_end = lm_status_manager.update_lm_status(state);
    if (dead_end) {
        return true;
    }
    // Add one temporary constraint sum_{o achieves LM} X_o >= 1 for every landmark LM.
    vector<LPConstraint> constraints;
    int num_landmarks = lm_graph.number_of_landmarks();
    for (int lm_id = 0; lm_id < num_landmarks; ++lm_id) {
        const LandmarkNode *lm = lm_graph.get_lm_for_index(lm_id);
        int lm_status = lm->get_status();
        if (lm_status != lm_reached) {
            const set<int> &achievers = get_achievers(lm_status, *lm);
            assert(!achievers.empty());
            LPConstraint constraint;
            set<int>::const_iterator ach_it;
            for (ach_it = achievers.begin(); ach_it != achievers.end();
                 ++ach_it) {
                int op_id = *ach_it;
                assert(op_id >= 0 && op_id < g_operators.size());
                constraint.insert(op_id, 1.0);
            }
            constraint.set_lower_bound(1.0);
            constraints.push_back(constraint);
        }
    }
    lp.add_temporary_constraints(constraints);
    return false;
}

/*
  Option-parser factory for the "lm_count_constraints" plugin.
  Registers the generic heuristic options and an "lm_graph" option, parses
  them, and returns a newly allocated LMCountConstraints, or 0 on a dry run.

  NOTE(review): this registers Heuristic options even though the plugin
  creates a ConstraintGenerator — confirm that is intended.
*/
static ConstraintGenerator *_parse(OptionParser &parser) {
    Heuristic::add_options_to_parser(parser);
    parser.add_option<LandmarkGraph *>("lm_graph");
    const Options opts = parser.parse();
    return parser.dry_run() ? 0 : new LMCountConstraints(opts);
}

static Plugin<ConstraintGenerator> _plugin("lm_count_constraints", _parse);
}
