#include "pdb_constraints.h"

#include "lp_constraint_collection.h"
#include "operator_count_lp.h"

#include "../plugin.h"

#include "../pdbs/pattern_generation_haslum.h"
#include "../pdbs/pattern_generation_systematic.h"
#include "../pdbs/canonical_pdbs_heuristic.h"
#include "../pdbs/pdb_heuristic.h"

#include "../rng.h"
#include "../causal_graph.h"
#include <math.h>

// for eval
#include "posthoc_optimization_heuristic.h"
#include "../max_heuristic.h"
#include "../globals.h"
#include "../state.h"
#include "../state_registry.h"
#include "../successor_generator.h"
namespace pho {

	// Builds the set of PDB heuristics that back the LP constraints.
	//
	// The patterns come from one of three sources, checked in this order:
	//   1. an explicit "patterns" option;
	//   2. systematic generation if "systematic" > 0. A fractional value such
	//      as 2.5 means: all patterns up to size 2, plus a fraction 0.5 of the
	//      size-3 layer;
	//   3. if the pattern list is still empty, PatternGenerationHaslum, whose
	//      canonical heuristic supplies the PDBs (marked as borrowed).
	// For sources 1 and 2 the PDBs are built one pattern at a time until one
	// of the limits (num_pdb_limit, last-layer ratio, size_limit, pho_eval
	// time budget) is reached.
	PDBConstraints::PDBConstraints(const Options &opts)
	: cost_type(OperatorCost(opts.get_enum("cost_type"))),
	canonical(0),
	size(0),
	num_pdb_limit(0)
	{
		vector<vector<int> > patterns;
		int maxSystematic = 0, size_limit = 0;
		float ratioLimit = 0;
		bool exactMaxSystematic = true;
		bool randomize = false;
		float pho_eval = 0;
		if (opts.contains("patterns")) {
			patterns = opts.get<vector<vector<int> > >("patterns");
		} else if (opts.contains("systematic") && opts.get<float>("systematic") > 0) {
			const float systematic = opts.get<float>("systematic");
			Options generator_opts;
			cout << "Systematic max size " << systematic << endl;
			if ((int) systematic == systematic) {
				// Integral value: generate every pattern up to exactly this size.
				maxSystematic = (int) systematic;
				ratioLimit = 0;
			} else {
				// Fractional value: generate one extra size layer and keep only
				// the fractional part of it (see the ratio check below).
				maxSystematic = (int) systematic + 1;
				ratioLimit = systematic - maxSystematic + 1;
				exactMaxSystematic = false;
				cout << " (last layer: " << maxSystematic << ", ratio: " << ratioLimit << ")";
			}
			if (opts.contains("num_pdb_limit")) {
				num_pdb_limit = opts.get<int>("num_pdb_limit");
			}
			size_limit = opts.get<int>("size_limit");
			cout << "Total PDBs size limit " << size_limit << endl;
			generator_opts.set<int>("pattern_max_size", maxSystematic);
			generator_opts.set<bool>("dominance_pruning", opts.get<bool>("dominance_pruning"));
			if (opts.contains("prune_irrelevant_patterns") && opts.get<bool>("prune_irrelevant_patterns")) {
				PatternGenerationSystematic pattern_generator(generator_opts);
				patterns = pattern_generator.get_patterns();
			} else {
				PatternGenerationSystematicNaive pattern_generator(generator_opts);
				patterns = pattern_generator.get_patterns();
			}
		}

		if (opts.contains("pho_eval")) {
			pho_eval = opts.get<float>("pho_eval");
		}

		// Set random seed
		if (opts.contains("random_seed")) {
			g_rng.seed(opts.get<int>("random_seed"));
			cout << "Random seed: " << opts.get<int>("random_seed") << endl;
			randomize = true;
		}

		cout << "Number of operators " << g_operators.size() << endl;
		int num_vars = g_variable_name.size();
		long num_arcs = 0;
		for (int var = 0; var < num_vars; ++var) {
			num_arcs += g_causal_graph->get_successors(var).size();
		}
		cout << "Number of nodes in causal graph " << num_vars << endl;
		cout << "Number of edges in causal graph " << num_arcs << endl;


		if (!patterns.empty()) {
			// Every PDB in this branch is allocated below and owned by this
			// object. The flag is set up front so that the destructor never
			// reads an uninitialized value when the loop stops before the
			// first PDB has been built.
			borrowed_heuristics = false;

			int currentSize = 0, sizeI = 0;
			float ratio = 0;
			Timer measureTimer;
			const int num_pattern = patterns.size();
			cout << "Number of PDBs limit " << num_pdb_limit << endl;

			vector<vector<int> > current_patterns;
			vector<State> samples;
			// One warm-up sample plus 20 timed samples for the pho_eval check.
			int pruning_samples = 20 + 1;
			if (pho_eval > 0) {
				// calculate average operator costs
				double average_operator_cost = 0;
				for (size_t i = 0; i < g_operators.size(); ++i) {
					average_operator_cost += get_adjusted_action_cost(g_operators[i], cost_type);
				}
				average_operator_cost /= g_operators.size();
				cout << "Average operator cost: " << average_operator_cost << endl;

				Options sample_opts;
				sample_opts.set<int>("cost_type", cost_type);
				HSPMaxHeuristic *sample_heuristic = new HSPMaxHeuristic(sample_opts);
				cout << "Sampling " << pruning_samples << " states" << endl;
				sample_states(samples, pruning_samples, average_operator_cost, sample_heuristic);
				delete sample_heuristic;
			}

			for (size_t i = 0; i < num_pattern && (num_pdb_limit==0 || i < num_pdb_limit); ++i) {
				if (maxSystematic > 0) {
					if (patterns[i].size() > currentSize) {
						// Entering a new size layer; sizeI remembers where it starts.
						currentSize = patterns[i].size();
						sizeI = i;
						cout << "Current pattern size: " << currentSize << ", i=" << i << endl;
						if (randomize && !exactMaxSystematic && currentSize == maxSystematic) {
							// Reached last size layer -> shuffle remaining patterns
							// so that the kept fraction is a random subset.
							for (int ii = patterns.size() - 1; ii > i; --ii) {
								int swp = g_rng.next(ii - i + 1) + i;
								patterns[ii].swap(patterns[swp]);
							}
						}
					} else if (patterns[i].size() < currentSize) {
						cout << "[ERROR] Patterns not sorted on size." << endl;
					}
				}

				// Stop if ratio is reached
				if (!exactMaxSystematic && maxSystematic > 0 && currentSize == maxSystematic) {
					ratio = (float) (i - sizeI + 1) / (patterns.size() - sizeI);
					//cout << "Ratio: " << ratio << " " << (i - sizeI) << " of " << (patterns.size() - sizeI) << endl;
					if (ratio > ratioLimit) {
						break;
					}
					cout << "Pattern: " << patterns[i] << endl;
				}

				// Stop before the accumulated number of abstract states reaches
				// size_limit (0 disables the limit). The product of the domain
				// sizes is accumulated in a wide type and abandoned as soon as
				// it reaches the limit, so it cannot overflow.
				if (size_limit > 0) {
					long long num_states = 1;
					for (size_t ii = 0; ii < patterns[i].size() && num_states < size_limit; ++ii) {
						num_states *= g_variable_domain[patterns[i][ii]];
					}
					if (size + num_states >= size_limit) {
						break;
					}
				}

				// Evaluate speed
				if (pho_eval > 0) {
					current_patterns.push_back(patterns[i]);
					if (i % 10 == 0) {
						// Build a throw-away generator from the patterns accepted
						// so far. Its options contain "patterns" but neither
						// "pho_eval" nor "random_seed", so the nested constructor
						// takes the explicit-patterns path without recursing or
						// reseeding the RNG.
						Options cg_opts;
						cg_opts.set<int>("cost_type", cost_type);
						cg_opts.set<vector<vector<int> > >("patterns", current_patterns);
						ConstraintGenerator* cg = new PDBConstraints(cg_opts);
						vector<ConstraintGenerator*> cgs;
						cgs.push_back(cg);

						Options pho_opts;
						pho_opts.set<int>("lpsolver", 0); // HACK
						pho_opts.set<int>("cost_type", cost_type);
						pho_opts.set<bool>("merge_lp_variables", false);
						pho_opts.set<int>("pruning_samples", 0);
						pho_opts.set<vector<ConstraintGenerator*> >("constraint_generators", cgs);
						cout << "Creating temporary pho-heuristic" << endl; // for easier reading of output
						PosthocOptimizationHeuristic *pho = new PosthocOptimizationHeuristic(pho_opts);

						// The first evaluation is excluded from the timing.
						pho->evaluate(samples[0]);
						Timer eval_time;
						for (size_t sample = 1; sample < pruning_samples; ++sample) {
							pho->evaluate(samples[sample]);
						}

						// Cleanup
						// NOTE(review): assumes the heuristic does not delete its
						// constraint generators itself -- confirm.
						delete pho;
						delete cg;

						// Stop if evaluation takes too long
						double avg_eval_time = eval_time.stop()/(pruning_samples - 1);
						cout << "[INFO] avg eval time: " << avg_eval_time << endl;
						if (avg_eval_time > pho_eval) {
							break;
						}
					}
				}

				Options pdb_opts;
				pdb_opts.set<int>("cost_type", cost_type);
				pdb_opts.set<vector<int> >("pattern", patterns[i]);
				PDBHeuristic *pdb = new PDBHeuristic(pdb_opts, false);
				size += pdb->get_size();

				heuristics.push_back(pdb);
			}
			cout << "[measure] Pattern gen time " << measureTimer << endl;
			cout << "Generated " << heuristics.size() << "/" << patterns.size() << " PDBs" << endl;
			cout << "Total size of all PDBs " << size << endl;
		} else {
			// compute pattern collection
			PatternGenerationHaslum pgh(opts);
			canonical = pgh.get_pattern_collection_heuristic();
			heuristics = canonical->get_pattern_databases();
			// These PDBs come from the canonical heuristic and are therefore
			// not deleted individually in the destructor.
			borrowed_heuristics = true;
			size = canonical->get_size();
		}
	}

        // Collects num_samples states by random walks from the initial state
        // and appends them to samples.
        //
        // samples:               output vector; sampled states are appended.
        // num_samples:           number of random walks (one sample per walk).
        // average_operator_cost: used to convert the initial heuristic value
        //                        into an estimated number of solution steps.
        // heuristic:             evaluated on the initial state to size the
        //                        walks and on every visited state to detect
        //                        dead ends.
        void PDBConstraints::sample_states(vector<State> &samples, int num_samples, double average_operator_cost, Heuristic *heuristic)
	{
		const State &initial_state = g_initial_state();
		heuristic->evaluate(initial_state);
		assert(!heuristic->is_dead_end());

		int h = heuristic->get_heuristic();
		int n;
		if (h == 0) {
			n = 10;
		} else {
			// Convert heuristic value into an approximate number of actions
			// (does nothing on unit-cost problems).
			// average_operator_cost cannot equal 0, as in this case, all operators
			// must have costs of 0 and in this case the if-clause triggers.
			int solution_steps_estimate = int((h / average_operator_cost) + 0.5);
			n = 4 * solution_steps_estimate;
		}
		double p = 0.5;
		// The expected walk length is np = 2 * estimated number of solution steps.
		// (We multiply by 2 because the heuristic is underestimating.)

		int count_samples = 0;
		samples.reserve(num_samples);

		int rng_count = 0;
		cout << "num_samples: " << num_samples << endl;
		cout << "n: " << n << endl;
		cout << "average operator cost: " << average_operator_cost << endl;

		// Buffer reused for every step of every walk.
		vector<const Operator *> applicable_ops;

		for (; count_samples < num_samples; ++count_samples) {
			// calculate length of random walk according to a binomial distribution
			int length = 0;
			for (int j = 0; j < n; ++j) {
				double random = g_rng(); // [0..1)
				rng_count++;
				if (random < p)
					++length;
			}

			// random walk of length length
			State current_state(initial_state);
			for (int j = 0; j < length; ++j) {
				applicable_ops.clear();
				g_successor_generator->generate_applicable_ops(current_state, applicable_ops);
				// if there are no applicable operators --> do not walk further
				if (applicable_ops.empty()) {
					break;
				} else {
					int random = g_rng.next(applicable_ops.size()); // [0..applicable_ops.size())
					rng_count++;
					assert(applicable_ops[random]->is_applicable(current_state));
					// TODO for now, we only generate registered successors. This is a temporary state that
					// should not necessarily be registered in the global registry: see issue386.
					current_state = g_state_registry->get_successor_state(current_state, *applicable_ops[random]);
					// if current state is a dead end, then restart with initial state
					heuristic->evaluate(current_state); // TODO: only evaluate if dead end or not
					if (heuristic->is_dead_end())
						current_state = initial_state;
				}
			}
			// last state of the random walk is used as sample
			samples.push_back(current_state);
		}

		cout << "rng_count = " << rng_count << endl;
	}
        
	// Releases the canonical heuristic (if any) and, unless the PDBs were
	// borrowed, every PDB stored in heuristics.
	PDBConstraints::~PDBConstraints()
	{
		delete canonical;
		canonical = 0;

		if (borrowed_heuristics) {
			// Borrowed PDBs are not deleted individually here.
			return;
		}
		for (vector<PDBHeuristic*>::iterator pdb = heuristics.begin(); pdb != heuristics.end(); ++pdb) {
			delete *pdb;
		}
	}

	// Creates one LP constraint per PDB and adds them to constraint_collection.
	//
	// filter: if non-empty, filter[i] decides whether the i-th PDB is kept;
	//         discarded PDBs are deleted unless they are borrowed. An empty
	//         filter keeps every PDB.
	// Each constraint holds, for every operator the PDB reports as relevant,
	// that operator's adjusted cost as coefficient. The index returned by
	// add_constraints() is remembered in constraint_offset.
	void
	PDBConstraints::initialize_constraints(LPConstraintCollection &constraint_collection, vector<bool> &filter)
	{
		if (!filter.empty()) {
			vector<PDBHeuristic*> kept;
			kept.reserve(heuristics.size());
			for (size_t pdb_index = 0; pdb_index < heuristics.size(); ++pdb_index) {
				PDBHeuristic *candidate = heuristics[pdb_index];
				if (filter[pdb_index]) {
					kept.push_back(candidate);
					continue;
				}
				// Filtered out: free it, but only if we own it.
				if (!borrowed_heuristics) {
					delete candidate;
				}
			}
			heuristics.swap(kept);
		}

		const size_t num_pdbs = heuristics.size();
		const size_t num_operators = g_operators.size();
		vector<LPConstraint> constraints(num_pdbs);
		for (size_t pdb_index = 0; pdb_index < num_pdbs; ++pdb_index) {
			LPConstraint &constraint = constraints[pdb_index];
			const std::vector<bool> &relevant = heuristics[pdb_index]->get_relevant_operators();
			for (size_t op_id = 0; op_id < num_operators; ++op_id) {
				if (!relevant[op_id]) {
					continue;
				}
				double op_cost = get_adjusted_action_cost(g_operators[op_id], cost_type);
				constraint.insert(op_id, op_cost);
			}
		}
		constraint_offset = constraint_collection.add_constraints(constraints);
	}

	// Forwards the state transition to every PDB heuristic.
	// Returns true if at least one of them returned true.
	bool
	PDBConstraints::reach_state(const State &parent_state, const Operator &op,
		const State &state)
	{
		bool any_dirty = false;
		const size_t num_pdbs = heuristics.size();
		for (size_t pdb_index = 0; pdb_index != num_pdbs; ++pdb_index) {
			// Call first, combine afterwards: every PDB must be notified even
			// when an earlier one has already reported true.
			const bool pdb_dirty = heuristics[pdb_index]->reach_state(parent_state, op, state);
			any_dirty = any_dirty || pdb_dirty;
		}
		return any_dirty;
	}

	bool
	PDBConstraints::update_constraints(const State &state, OperatorCountLP &lp)
	{
		for (size_t i = 0; i < heuristics.size(); ++i) {
			int constraint_id = constraint_offset + i;
			PDBHeuristic *h = heuristics[i];
			h->evaluate(state);
			if (h->is_dead_end()) {
				return true;
			}
			int h_val = h->get_heuristic();
			lp.set_permanent_constraint_lower_bound(constraint_id, h_val);
		}
		return false;
	}

	// Option-parser entry point for the "pdb_constraints" plugin.
	// Registers the options of PatternGenerationHaslum, the options specific
	// to this generator and the generic heuristic options, then parses them.
	// Returns 0 in help mode and in dry-run mode, otherwise a newly allocated
	// PDBConstraints.
	static ConstraintGenerator *
	_parse(OptionParser &parser)
	{
		PatternGenerationHaslum::create_options(parser);

		parser.add_option<float>(
			"systematic",
			"Systematically generate all patterns with up to n variables instead of using PatternGenerationHaslum.",
			"0.0");
		parser.add_option<bool>(
			"prune_irrelevant_patterns",
			"Prune irrelevant patterns before building the LP.",
			"true");
		parser.add_option<int>(
			"size_limit",
			"Maximum number of abstract states in all generated PDBs combines.",
			"0");
		parser.add_option<int>(
			"num_pdb_limit",
			"Maximum number of PDBs to generate.",
			"0");
		parser.add_option<int>(
			"random_seed",
			"Random seed to use for pattern ordering",
			"2011");
		parser.add_option<float>(
			"pho_eval",
			"Limit on maximum evaluation time per state on pho heuristic",
			"0");

		Heuristic::add_options_to_parser(parser);
		Options opts = parser.parse();

		ConstraintGenerator *generator = 0;
		if (!parser.help_mode()) {
			// The sanity check is skipped in help mode but still runs in a dry run.
			PatternGenerationHaslum::sanity_check_options(parser, opts);
			if (!parser.dry_run()) {
				generator = new PDBConstraints(opts);
			}
		}
		return generator;
	}

	// File-local plugin object tying the key "pdb_constraints" to _parse.
	// NOTE(review): presumably the Plugin constructor registers the parser
	// during static initialization -- confirm in plugin.h.
	static Plugin<ConstraintGenerator> _plugin("pdb_constraints", _parse);
}
