#include "split_selector.h"

#include "abstraction.h"
#include "abstract_state.h"
#include "transition_system.h"
#include "abstract_search.h"
#include "utils.h"

#include "../heuristics/additive_heuristic.h"

#include "../task_utils/task_properties.h"
#include "../utils/logging.h"
#include "../utils/rng.h"

#include <cassert>
#include <iostream>
#include <limits>
#include <unordered_set>

using namespace std;

namespace cegar {
SplitSelector::SplitSelector(
    const shared_ptr<AbstractTask> &task,
    PickSplit pick)
    : task(task),
      task_proxy(*task),
      pick(pick),
      num_picks(0),
      num_options(0),
      num_ratings(0) {
    // Only some strategies need auxiliary data; decide up front which ones.
    const bool rates_by_hadd =
        pick == PickSplit::MIN_HADD || pick == PickSplit::MAX_HADD;
    const bool rates_by_goal_distances =
        pick == PickSplit::MIN_GOAL_DIST || pick == PickSplit::MAX_GOAL_DIST ||
        pick == PickSplit::MIN_HIGHER_DIST || pick == PickSplit::MAX_HIGHER_DIST;

    if (rates_by_hadd) {
        // h^add values are evaluated once, for the task's initial state, and
        // later queried per fact by get_hadd_value().
        additive_heuristic = create_additive_heuristic(task);
        additive_heuristic->compute_heuristic_for_cegar(
            task_proxy.get_initial_state());
    }
    if (rates_by_goal_distances) {
        // Operator costs are needed by compute_distances() in the
        // goal-distance based ratings.
        op_costs = utils::make_unique_ptr<vector<int>>(
            task_properties::get_operator_costs(task_proxy));
    }
}

// Defined out of line so that the header does not need the complete types of
// the members it owns (e.g. the additive heuristic).
SplitSelector::~SplitSelector() = default;

/*
  Number of values of split.var_id that remain in `state` but are not listed
  in split.values. Asserted to be at least 1.
*/
int SplitSelector::get_num_unwanted_values(
    const AbstractState &state, const Split &split) const {
    // Subtract in signed arithmetic. "int - size_t" would be evaluated as
    // size_t and wrap around on a violated precondition instead of going
    // negative, which the assertion below would then not reliably catch.
    const int num_remaining = static_cast<int>(state.count(split.var_id));
    const int num_wanted = static_cast<int>(split.values.size());
    int num_unwanted_values = num_remaining - num_wanted;
    assert(num_unwanted_values >= 1);
    return num_unwanted_values;
}

/*
  Negated fraction of var_id's domain that is still contained in `state`.
  The result lies in [-1, 0): values closer to 0 mean fewer values remain
  relative to the domain size.
*/
double SplitSelector::get_refinedness(const AbstractState &state, int var_id) const {
    const double domain_size =
        task_proxy.get_variables()[var_id].get_domain_size();
    assert(domain_size >= 2);
    const double num_remaining = state.count(var_id);
    assert(2 <= num_remaining && num_remaining <= domain_size);
    const double remaining_fraction = num_remaining / domain_size;
    // Equivalent to asserting -1.0 <= -fraction < 0.0.
    assert(0.0 < remaining_fraction && remaining_fraction <= 1.0);
    return -remaining_fraction;
}

/*
  Cost of the fact (var_id, value) as reported by the additive heuristic.
  Requires the heuristic to have been created (MIN_HADD / MAX_HADD).
*/
int SplitSelector::get_hadd_value(int var_id, int value) const {
    assert(additive_heuristic);
    const int fact_cost = additive_heuristic->get_cost_for_cegar(var_id, value);
    // NOTE(review): -1 presumably marks an unreached fact; confirm with the
    // additive heuristic's documentation.
    assert(fact_cost != -1);
    return fact_cost;
}

int SplitSelector::get_min_hadd_value(int var_id, const vector<int> &values) const {
    int min_hadd = numeric_limits<int>::max();
    for (int value : values) {
        const int hadd = get_hadd_value(var_id, value);
        if (hadd < min_hadd) {
            min_hadd = hadd;
        }
    }
    return min_hadd;
}

int SplitSelector::get_max_hadd_value(int var_id, const vector<int> &values) const {
    int max_hadd = -1;
    for (int value : values) {
        const int hadd = get_hadd_value(var_id, value);
        if (hadd > max_hadd) {
            max_hadd = hadd;
        }
    }
    return max_hadd;
}

double SplitSelector::get_avg_goal_dist(const Abstraction &abstraction,
                                        const AbstractState &state,
                                        const Split &split) const {
    // copy the states so we don't destroy the existing ones
    pair<CartesianSet, CartesianSet> cartesian_sets = state.split_domain(split.var_id, split.values);
    AbstractState v1(state.get_id(), -1, move(cartesian_sets.first));
    AbstractState v2(abstraction.get_num_states(), -2, move(cartesian_sets.second));
    TransitionSystem t_copy(abstraction.get_transition_system());
    // t_copy.rewire(<states>, state.get_id(), v1, v2, split.var_id);
    // delegate rewiring to abstraction so we do not have to copy the states
    abstraction.rewire(t_copy, state.get_id(), v1, v2, split.var_id);
    vector<int> goal_dist = compute_distances(t_copy.get_incoming_transitions(),
                                              *op_costs, abstraction.get_goals());

    double avg_dist = 0, norm = 0;
    for (int s_id = 0; s_id <= abstraction.get_num_states(); s_id++) {
        // avoid degenerate case where any one state cannot reach the goal
        if (goal_dist[s_id] != INF) {
            // avoids copying the state vector
            const AbstractState *curr;
            if (s_id == state.get_id())
                curr = &v1;
            else if (s_id == abstraction.get_num_states())
                curr = &v2;
            else
                curr = &abstraction.get_state(s_id);
            double part = 1;
            for (int var = 0; var < curr->var_count(); var++) {
                // avoid overflow for large domains
                part *= static_cast<double>(curr->count(var)) / curr->size(var);
            }
            avg_dist += goal_dist[s_id] * part;
            norm += part;
        }
    }
    if (norm == 0)
        return numeric_limits<double>::infinity();
    return avg_dist / norm;
}

int SplitSelector::get_num_inc_goal_dist(const Abstraction &abstraction,
                                         const AbstractState &state,
                                         const Split &split) const {
    TransitionSystem t_copy(abstraction.get_transition_system());
    vector<int> old_goal_dist = compute_distances(t_copy.get_incoming_transitions(),
                                                  *op_costs, abstraction.get_goals());
    pair<CartesianSet, CartesianSet> cartesian_sets = state.split_domain(split.var_id, split.values);
    AbstractState v1(state.get_id(), -1, move(cartesian_sets.first));
    AbstractState v2(abstraction.get_num_states(), -2, move(cartesian_sets.second));
    abstraction.rewire(t_copy, state.get_id(), v1, v2, split.var_id);
    vector<int> new_goal_dist = compute_distances(t_copy.get_incoming_transitions(),
                                                  *op_costs, abstraction.get_goals());

    int num_increased = 0;
    for (int s_id = 0; s_id < abstraction.get_num_states(); s_id++) {
        // includes split state, it's other split target is checked outside the loop
        if (new_goal_dist[s_id] > old_goal_dist[s_id])
            num_increased++;
    }
    if (new_goal_dist[abstraction.get_num_states()] > old_goal_dist[state.get_id()])
        num_increased++;
    return num_increased;
}

size_t SplitSelector::get_num_active_ops(const Abstraction &abstraction,
                                         const AbstractState &state,
                                         const Split &split) const {
    pair<CartesianSet, CartesianSet> cartesian_sets = state.split_domain(split.var_id, split.values);
    AbstractState v1(state.get_id(), -1, move(cartesian_sets.first));
    AbstractState v2(abstraction.get_num_states(), -2, move(cartesian_sets.second));
    TransitionSystem t_copy(abstraction.get_transition_system());
    abstraction.rewire(t_copy, state.get_id(), v1, v2, split.var_id);
    std::unordered_set<int> active;
    // only require either incoming or outgoing to get the operators
    auto incoming = t_copy.get_incoming_transitions();
    for (auto it = incoming.begin(); it != incoming.end(); it++) {
        for (auto t = it->begin(); t != it->end(); t++) {
            active.insert(t->op_id);
        }
    }
    return active.size();
}

/*
  Rate `split` for `state` according to the configured strategy.
  pick_split() selects the split with the HIGHEST rating, so MIN_* strategies
  return the negated quantity and MAX_* strategies return it as is.
  An unknown strategy prints an error and exits with SEARCH_INPUT_ERROR.
*/
double SplitSelector::rate_split(const Abstraction &abstraction,
                                 const AbstractState &state,
                                 const Split &split) const {
    int var_id = split.var_id;
    const vector<int> &values = split.values;
    // Initialized so no path can read an indeterminate value.
    double rating = 0.0;
    switch (pick) {
    case PickSplit::MIN_UNWANTED:
        rating = -get_num_unwanted_values(state, split);
        break;
    case PickSplit::MAX_UNWANTED:
        rating = get_num_unwanted_values(state, split);
        break;
    case PickSplit::MIN_REFINED:
        rating = -get_refinedness(state, var_id);
        break;
    case PickSplit::MAX_REFINED:
        rating = get_refinedness(state, var_id);
        break;
    case PickSplit::MIN_HADD:
        rating = -get_min_hadd_value(var_id, values);
        break;
    case PickSplit::MAX_HADD:
        rating = get_max_hadd_value(var_id, values);
        break;
    case PickSplit::MIN_CG:
        rating = -var_id;
        break;
    case PickSplit::MAX_CG:
        rating = var_id;
        break;
    case PickSplit::MIN_GOAL_DIST:
        rating = -get_avg_goal_dist(abstraction, state, split);
        break;
    case PickSplit::MAX_GOAL_DIST:
        rating = get_avg_goal_dist(abstraction, state, split);
        break;
    case PickSplit::MIN_HIGHER_DIST:
        rating = -get_num_inc_goal_dist(abstraction, state, split);
        break;
    case PickSplit::MAX_HIGHER_DIST:
        rating = get_num_inc_goal_dist(abstraction, state, split);
        break;
    case PickSplit::MIN_ACTIVE_OPS:
        // get_num_active_ops() returns size_t, so convert BEFORE negating.
        // Unary minus on an unsigned value wraps to 2^64 - n, which rounds to
        // (almost) the same double for every small n, making all candidates
        // tie.
        rating = -static_cast<double>(get_num_active_ops(abstraction, state, split));
        break;
    case PickSplit::MAX_ACTIVE_OPS:
        rating = static_cast<double>(get_num_active_ops(abstraction, state, split));
        break;
    default:
        cout << "Invalid pick strategy: " << static_cast<int>(pick) << endl;
        utils::exit_with(utils::ExitCode::SEARCH_INPUT_ERROR);
    }
    return rating;
}

/*
  Choose one of the candidate `splits` (must be non-empty) for `state`.

  With a single candidate it is returned directly. With PickSplit::RANDOM a
  uniformly chosen candidate is returned via `rng`. Otherwise every candidate
  is rated with rate_split() and the highest-rated one is returned; ties go to
  the earliest candidate. Also updates the statistics counters used by
  print_statistics().
*/
const Split &SplitSelector::pick_split(const Abstraction &abstraction,
                                       const AbstractState &state,
                                       const vector<Split> &splits,
                                       utils::RandomNumberGenerator &rng) {
    assert(!splits.empty());
    num_picks++;
    num_options += splits.size();

    if (splits.size() == 1) {
        num_ratings += 1;
        return splits[0];
    }

    if (pick == PickSplit::RANDOM) {
        num_ratings += 1;
        return *rng.choose(splits);
    }

    // Used for counting the number of distinct ratings.
    unordered_set<double> ratings;
    // Seed with the first candidate and -infinity, not lowest(): ratings can
    // be -infinity (e.g. MIN_GOAL_DIST when get_avg_goal_dist() returns +inf).
    // Previously, if every rating was -infinity, no candidate beat lowest()
    // and a null pointer was dereferenced in release builds.
    const Split *selected_split = &splits.front();
    double max_rating = -numeric_limits<double>::infinity();
    for (const Split &split : splits) {
        double rating = rate_split(abstraction, state, split);
        ratings.insert(rating);
        // Strict comparison: the earliest of equally rated splits wins.
        if (rating > max_rating) {
            selected_split = &split;
            max_rating = rating;
        }
    }
    num_ratings += ratings.size();
    if (pick == PickSplit::MIN_CG || pick == PickSplit::MAX_CG) {
        // CG ratings are the variable ids, expected to be distinct per split.
        assert(ratings.size() == splits.size());
    }
    return *selected_split;
}

/*
  Print the average number of candidate splits and of distinct ratings per
  pick. Prints nothing if no pick has happened, which avoids dividing by zero
  and printing NaN.
*/
void SplitSelector::print_statistics() const {
    if (num_picks <= 0)
        return;
    const double picks = static_cast<double>(num_picks);
    const double avg_options = static_cast<double>(num_options) / picks;
    const double avg_ratings = static_cast<double>(num_ratings) / picks;
    cout << "Average number of possible splits: " << avg_options << endl;
    cout << "Average number of distinct ratings: " << avg_ratings << endl;
}
}
