#include "pattern_generation_systematic.h"

#include "canonical_pdbs_heuristic.h"

#include "../causal_graph.h"
#include "../globals.h"
#include "../plugin.h"


#include <algorithm>
#include <iostream>

using namespace std;
using namespace std::tr1;


/*
  Return true iff the two patterns have no variable in common.

  The scan advances one cursor per step, always the one pointing at the
  smaller value, so it only gives the right answer if both patterns are
  sorted in increasing order.
*/
static bool patterns_are_disjoint(const Pattern &pattern1, const Pattern &pattern2) {
    const size_t size1 = pattern1.size();
    const size_t size2 = pattern2.size();
    size_t pos1 = 0;
    size_t pos2 = 0;
    while (pos1 < size1 && pos2 < size2) {
        const int var1 = pattern1[pos1];
        const int var2 = pattern2[pos2];
        if (var1 < var2) {
            ++pos1;
        } else if (var2 < var1) {
            ++pos2;
        } else {
            // Same variable in both patterns.
            return false;
        }
    }
    // One pattern is exhausted without finding a shared variable.
    return true;
}


static void compute_union_pattern(const Pattern &pattern1, const Pattern &pattern2, Pattern &result) {
    result.clear();
    result.reserve(pattern1.size() + pattern2.size());
    set_union(pattern1.begin(), pattern1.end(),
              pattern2.begin(), pattern2.end(),
              back_inserter(result));
}


/*
  Read the maximum pattern size from the "pattern_max_size" option and
  generate all patterns right away via build_patterns().
*/
PatternGenerationSystematic::PatternGenerationSystematic(const Options &opts)
    : max_pattern_size(opts.get<int>("pattern_max_size")) {
    build_patterns();
}


// Nothing to release explicitly; the destructor body is intentionally empty.
PatternGenerationSystematic::~PatternGenerationSystematic() {
}


void PatternGenerationSystematic::compute_eff_pre_neighbors(
    const Pattern &pattern, vector<int> &result) const {
    /*
      Compute all variables that are reachable from pattern by an
      (eff, pre) arc and are not already contained in the pattern.
    */

    unordered_set<int> candidates;

    // Compute neighbors.
    for (size_t i = 0; i < pattern.size(); ++i) {
        int var = pattern[i];
        const vector<int> &neighbors = g_causal_graph->get_eff_to_pre(var);
        candidates.insert(neighbors.begin(), neighbors.end());
    }

    // Remove elements of pattern.
    for (size_t i = 0; i < pattern.size(); ++i) {
        int var = pattern[i];
        candidates.erase(var);
    }

    result.assign(candidates.begin(), candidates.end());
}

void PatternGenerationSystematic::compute_connection_points(
    const Pattern &pattern, vector<int> &result) const {
    /*
      The "connection points" of a pattern are those variables of which
      one must be contained in an SGA pattern that can be attached to this
      pattern to form a larger interesting pattern. (Interesting patterns
      are disjoint unions of SGA patterns.)

      A variable is a connection point if it satisfies the following criteria:
      1. We can get from the pattern to the connection point via
         an (eff, pre) or (eff, eff) arc in the causal graph.
      2. It is not part of pattern.
      3. We *cannot* get from the pattern to the connection point
         via an (eff, pre) arc.

      Condition 1. is the important one. The other conditions are
      optimizations that help reduce the number of candidates to
      consider.
    */
    unordered_set<int> candidates;

    // Handle rule 1.
    for (size_t i = 0; i < pattern.size(); ++i) {
        int var = pattern[i];
        const vector<int> &pred = g_causal_graph->get_predecessors(var);
        candidates.insert(pred.begin(), pred.end());
    }

    // Handle rules 2 and 3.
    for (size_t i = 0; i < pattern.size(); ++i) {
        int var = pattern[i];
        // Rule 2:
        candidates.erase(var);
        // Rule 3:
        const vector<int> &eff_pre = g_causal_graph->get_eff_to_pre(var);
        for (size_t j = 0; j < eff_pre.size(); ++j)
            candidates.erase(eff_pre[j]);
    }

    result.assign(candidates.begin(), candidates.end());
}


/*
  Append pattern to the "patterns" queue unless it has been seen before.
  "pattern_set" records the patterns seen so far.
*/
void PatternGenerationSystematic::enqueue_pattern_if_new(const Pattern &pattern) {
    const bool newly_inserted = pattern_set.insert(pattern).second;
    if (!newly_inserted)
        return;
    patterns.push_back(pattern);
}


void PatternGenerationSystematic::build_sga_patterns() {
    assert(max_pattern_size >= 1);
    assert(pattern_set.empty());
    assert(patterns.empty());

    /*
      SGA patterns are "single-goal ancestor" patterns, i.e., those
      patterns which can be generated by following eff/pre arcs from a
      single goal variable.

      This method must generate all SGA patterns up to size
      "max_pattern_size". They must be generated in order of
      increasing size, and they must be placed in "patterns".

      The overall structure of this is a similar processing queue as
      in the main pattern generation method below, and we reuse
      "patterns" and "pattern_set" between the two methods.
    */

    // Build goal patterns.
    for (size_t i = 0; i < g_goal.size(); ++i) {
        int var = g_goal[i].first;
        Pattern goal_pattern;
        goal_pattern.push_back(var);
        enqueue_pattern_if_new(goal_pattern);
    }

    /*
      Grow SGA patterns until all patterns are processed. Note that
      the patterns vectors grows during the computation.
    */
    for (size_t pattern_no = 0; pattern_no < patterns.size(); ++pattern_no) {
        // We must copy the pattern because references to patterns can be invalidated.
        Pattern pattern = patterns[pattern_no];

        if (pattern.size() == max_pattern_size)
            break;

        vector<int> neighbors;
        compute_eff_pre_neighbors(pattern, neighbors);
		
		//double chance = 1.0 / (pattern.size() + 1);
        for (size_t i = 0; i < neighbors.size(); ++i) {
			//double random = g_rng(); // [0..1)
			//if(random < chance){
				int neighbor_var = neighbors[i];
				Pattern new_pattern(pattern);
				new_pattern.push_back(neighbor_var);
				sort(new_pattern.begin(), new_pattern.end());

				enqueue_pattern_if_new(new_pattern);
			//}
        }
    }

    pattern_set.clear();
}


void PatternGenerationSystematic::build_patterns() {
    int num_variables = g_variable_domain.size();

    // Generate SGA (single-goal-ancestor) patterns.
    // They are generated into the patterns variable,
    // so we swap them from there.
    build_sga_patterns();
    vector<Pattern> sga_patterns;
    patterns.swap(sga_patterns);

    /* Index the SGA patterns by variable.

       Important: sga_patterns_by_var[var] must be sorted by size.
       This is guaranteed because build_sga_patterns generates
       patterns ordered by size.
    */
    vector<vector<const Pattern *> > sga_patterns_by_var(num_variables);
    for (size_t i = 0; i < sga_patterns.size(); ++i) {
        const Pattern &pattern = sga_patterns[i];
        for (size_t j = 0; j < pattern.size(); ++j) {
            int var = pattern[j];
            sga_patterns_by_var[var].push_back(&pattern);
        }
    }

    // Enqueue the SGA patterns.
    for (size_t i = 0; i < sga_patterns.size(); ++i)
        enqueue_pattern_if_new(sga_patterns[i]);


    cout << "Found " << sga_patterns.size() << " SGA patterns." << endl;

    /*
      Combine patterns in the queue with SGA patterns until all
      patterns are processed. Note that the patterns vectors grows
      during the computation.
    */
    for (size_t pattern_no = 0; pattern_no < patterns.size(); ++pattern_no) {
        // We must copy the pattern because references to patterns can be invalidated.
        Pattern pattern1 = patterns[pattern_no];

        vector<int> neighbors;
        compute_connection_points(pattern1, neighbors);

        for (size_t i = 0; i < neighbors.size(); ++i) {
            int neighbor_var = neighbors[i];
            const vector<const Pattern *> &candidates = sga_patterns_by_var[neighbor_var];
            for (size_t j = 0; j < candidates.size(); ++j) {
                const Pattern &pattern2 = *candidates[j];
                if (pattern1.size() + pattern2.size() > max_pattern_size)
                    break;  // All remaining candidates are too large.
                if (patterns_are_disjoint(pattern1, pattern2)) {
                    Pattern new_pattern;
                    compute_union_pattern(pattern1, pattern2, new_pattern);
                    enqueue_pattern_if_new(new_pattern);
                }
            }
        }
    }

    pattern_set.clear();
    cout << "Found " << patterns.size() << " interesting patterns." << endl;

    /*
    cout << "list of patterns:" << endl;
    for (size_t i = 0; i < patterns.size(); ++i)
        cout << patterns[i] << endl;
    */
}


// Return a read-only reference to the generated patterns. The reference
// points at a member, so it is only usable while this object is alive.
const vector<Pattern> &PatternGenerationSystematic::get_patterns() const {
    return patterns;
}


/*
  Build a canonical PDBs heuristic over the generated patterns.

  The "cost_type" option is forwarded from opts. If the
  "dominance_pruning" option is set, dominance pruning is applied to
  the heuristic before it is returned.

  The heuristic is allocated with new and is not deleted here.
*/
CanonicalPDBsHeuristic *PatternGenerationSystematic::get_pattern_collection_heuristic(const Options &opts) const {
    Options canonical_opts;
    const OperatorCost cost_type = OperatorCost(opts.get<int>("cost_type"));
    canonical_opts.set<int>("cost_type", cost_type);
    canonical_opts.set("patterns", patterns);

    CanonicalPDBsHeuristic *heuristic = new CanonicalPDBsHeuristic(canonical_opts);
    const bool use_dominance_pruning = opts.get<bool>("dominance_pruning");
    if (use_dominance_pruning)
        heuristic->dominance_pruning();
    return heuristic;
}

/*
  Generate every pattern (sorted set of variables) with between 1 and
  "pattern_max_size" variables, without any relevance filtering.

  Patterns are built size by size: each pattern of the previous size is
  extended by every variable larger than its last (= largest) variable.
  All generated patterns are appended to "patterns", ordered by size.
*/
PatternGenerationSystematicNaive::PatternGenerationSystematicNaive(const Options &opts) {
    const int pattern_max_size = opts.get<int>("pattern_max_size");
    const int num_variables = g_variable_domain.size();

    // Start from the single empty pattern (size 0).
    vector<vector<int> > current_patterns(1);
    vector<vector<int> > next_patterns;

    /*
      The loop counters are deliberately signed. Comparing an unsigned
      counter against these int bounds would convert a negative bound to
      a huge unsigned value and make the loop run (almost) forever.

      We also stop as soon as there is nothing left to extend, which
      happens once the size exceeds the number of variables.
    */
    for (int size = 0; size < pattern_max_size && !current_patterns.empty(); ++size) {
        for (size_t j = 0; j < current_patterns.size(); ++j) {
            // Safe to hold a reference: current_patterns is not modified
            // inside this loop.
            const vector<int> &base = current_patterns[j];
            // Only add variables above the largest one to keep patterns
            // sorted and to generate each set exactly once.
            const int first_var = base.empty() ? 0 : base.back() + 1;
            for (int var = first_var; var < num_variables; ++var) {
                vector<int> pattern(base);
                pattern.push_back(var);
                next_patterns.push_back(pattern);
                patterns.push_back(pattern);
            }
        }
        next_patterns.swap(current_patterns);
        next_patterns.clear();
    }

    cout << "Found " << patterns.size() << " patterns." << endl;
}

// Nothing to release explicitly; the destructor body is intentionally empty.
PatternGenerationSystematicNaive::~PatternGenerationSystematicNaive() {
}

/*
  Build a canonical PDBs heuristic over the naively generated patterns.

  The "cost_type" option is forwarded from opts. If the
  "dominance_pruning" option is set, dominance pruning is applied to
  the heuristic before it is returned.

  The heuristic is allocated with new and is not deleted here.
*/
CanonicalPDBsHeuristic *PatternGenerationSystematicNaive::get_pattern_collection_heuristic(const Options &opts) const {
    Options canonical_opts;
    const OperatorCost cost_type = OperatorCost(opts.get<int>("cost_type"));
    canonical_opts.set<int>("cost_type", cost_type);
    canonical_opts.set("patterns", patterns);

    CanonicalPDBsHeuristic *heuristic = new CanonicalPDBsHeuristic(canonical_opts);
    const bool use_dominance_pruning = opts.get<bool>("dominance_pruning");
    if (use_dominance_pruning)
        heuristic->dominance_pruning();
    return heuristic;
}



/*
  Option parser callback for the "systematic_canonical" plugin.

  Declares the options of this heuristic, validates "pattern_max_size"
  and, unless this is a dry run, generates the patterns and returns the
  resulting canonical PDBs heuristic. Returns 0 on a dry run.

  With "prune_irrelevant_patterns" enabled, only interesting patterns
  are generated (PatternGenerationSystematic); otherwise all patterns
  up to the size limit are used (PatternGenerationSystematicNaive).
*/
static ScalarEvaluator *_parse(OptionParser &parser) {
    parser.add_option<int>("pattern_max_size",
                           "max number of variables per pattern",
                           "1");
    parser.add_option<bool>("dominance_pruning",
                            "Use dominance pruning to reduce number of patterns.",
                            "true");
    parser.add_option<bool>("prune_irrelevant_patterns",
                            "Prune irrelevant patterns before building the LP.",
                            "true");
    Heuristic::add_options_to_parser(parser);

    Options opts = parser.parse();
    const int max_size = opts.get<int>("pattern_max_size");
    if (max_size < 1)
        parser.error("number of variables per pattern must be at least 1");

    if (parser.dry_run())
        return 0;

    const bool prune_irrelevant =
        opts.contains("prune_irrelevant_patterns") &&
        opts.get<bool>("prune_irrelevant_patterns");
    if (!prune_irrelevant) {
        PatternGenerationSystematicNaive naive_generator(opts);
        return naive_generator.get_pattern_collection_heuristic(opts);
    }

    // The generator is only needed to produce the heuristic; the
    // heuristic is heap-allocated and outlives it.
    PatternGenerationSystematic generator(opts);
    return generator.get_pattern_collection_heuristic(opts);
}


// File-local plugin object tying the name "systematic_canonical" to _parse.
// NOTE(review): presumably the Plugin constructor registers the parser
// globally at static-initialization time -- confirm against plugin.h.
static Plugin<ScalarEvaluator> _plugin("systematic_canonical", _parse);
