#include "posthoc_optimization_heuristic.h"

#include "lp_constraint_collection.h"

#include "../globals.h"
#include "../rng.h"
#include "../lp_solver_interface.h"
#include "../plugin.h"
#include "../timer.h"
#include "../utilities.h"
#include "../max_heuristic.h"

#include "../state.h"
#include "../state_registry.h"
#include "../successor_generator.h"

using namespace std;

namespace pho {

	/*
	  Construct the heuristic from parsed options.

	  The LP itself is not built here (lp starts out as a null pointer); it is
	  created in initialize(). Both pruning settings start at 0 and are only
	  overwritten when the corresponding option is present in opts. A copy of
	  the options is stored in the member of the same name.
	*/
	PosthocOptimizationHeuristic::PosthocOptimizationHeuristic(const Options &opts)
		: Heuristic(opts),
		  merge_lp_variables(opts.get<bool>("merge_lp_variables")),
		  constraint_generators(opts.get_list<ConstraintGenerator *>("constraint_generators")),
		  lp_solver_type(LPSolverType(opts.get_enum("lpsolver"))),
		  lp(0),
		  pruning_samples(0),
		  prune_num_constraints(0),
		  opts(opts)
	{
		// Within this body, "opts" names the constructor parameter, not the member copy.
		if (opts.contains("pruning_samples")) {
			pruning_samples = opts.get<int>("pruning_samples");
		}
		if (opts.contains("prune_num_constraints")) {
			prune_num_constraints = opts.get<int>("prune_num_constraints");
		}
		// No check is made here that the two pruning options are used together.
	}

	/* Free the LP created in initialize(); deleting a null pointer is a no-op. */
	PosthocOptimizationHeuristic::~PosthocOptimizationHeuristic() {
		delete lp;
	}

	/*
	  Print the LP statistics labelled "after search".

	  The LP is only created in initialize(), so lp may still be null when this
	  is called on a heuristic that was never initialized; in that case nothing
	  is printed instead of dereferencing a null pointer.
	*/
	void PosthocOptimizationHeuristic::print_statistics() const
	{
		if (lp) {
			lp->print_statistics("after search");
		}
	}

        /*
          Ordering predicate for std::sort: orders pairs by their second
          component, largest first. The first component is ignored, so pairs
          with equal second components compare as equivalent.
        */
        bool pair_sorter(std::pair<int,int> i, std::pair<int,int> j)
        {
            return j.second < i.second;
        }
        
        /*bool filter_num(int i, vector<pair<int,int> > &counts){
            return (i >= counts.size()/2);
        }
        
        bool filter_occ(int i, vector<pair<int,int> > &counts){
            return (counts[i].second > 0);
        }*/
        
	void PosthocOptimizationHeuristic::initialize()
	{
		assert(!lp);
                Timer init_timer;
                
		vector<bool> filter_constraints;
		
		int MIN_OCCURRENCE = 1;
		if (pruning_samples > 0) {
			Timer prune_timer;
			
			// calculate average operator costs
			double average_operator_cost = 0;
			for (size_t i = 0; i < g_operators.size(); ++i) {
				average_operator_cost += get_adjusted_action_cost(g_operators[i], cost_type);
			}
			average_operator_cost /= g_operators.size();
			cout << "Average operator cost: " << average_operator_cost << endl;
			
			Options opts(this->opts);
			opts.set<int>("pruning_samples", 0);
			cout << "Creating temporary pho-heuristic" << endl; // for easier reading of output
			PosthocOptimizationHeuristic *heuristic = new PosthocOptimizationHeuristic(opts);
			heuristic->initialize();
			
			Options sample_opts(this->opts);
			HSPMaxHeuristic *sample_heuristic = new HSPMaxHeuristic(sample_opts);
			
			int num_constraints = heuristic->lp->get_num_constraints();
			vector<pair<int,int> > binding_count;
                        binding_count.reserve(num_constraints);
                        for (size_t constraint = 0; constraint < num_constraints; ++constraint) {
                            binding_count.push_back(make_pair(constraint,0));
                        }
			vector<State> samples;
			cout << "Sampling " << pruning_samples << " states" << endl;
			sample_states(samples, pruning_samples, average_operator_cost, sample_heuristic);
			for (size_t sample = 0; sample < pruning_samples; ++sample) {
				heuristic->evaluate(samples[sample]);
				heuristic->lp->update_binding_count(binding_count);
			}
                        
                        cout << "Starting sorting...";
                        std::sort(binding_count.begin(), binding_count.end(), pair_sorter);
			cout << "done" << endl;
                        
			int num_unused_constraints = 0;
			filter_constraints.resize(num_constraints, false);
                        //bool (*func)(int, vector<pair<int,int> >&) = &filter_num;
			for (size_t constraint = 0; constraint < num_constraints; ++constraint) {
				if((prune_num_constraints == 0 && binding_count[constraint].second > 0)||(prune_num_constraints > 0 && constraint < prune_num_constraints)){
					filter_constraints[binding_count[constraint].first] = true; // keep
				} else {
				    ++num_unused_constraints;
				}
			}
			delete heuristic;
			delete sample_heuristic;
                        cout << "Number of constraints: " << (num_constraints - num_unused_constraints) << endl;
			//cout << "Removed " << num_unused_constraints << " of " << num_constraints << " constraints from LP" << endl;
			cout << "Pruning time: " << prune_timer << endl;
		}
		
		Timer compile_timer;
		LPConstraintCollection constraint_collection;
		for (size_t i = 0; i < g_operators.size(); ++i) {
                    int op_cost = get_adjusted_action_cost(g_operators[i], cost_type);
                    constraint_collection.add_variable(
                        LPVariable(0, numeric_limits<double>::infinity(), op_cost));
		}
		for (size_t i = 0; i < constraint_generators.size(); ++i) {
                    constraint_generators[i]->initialize_constraints(constraint_collection, filter_constraints);
		}
		lp = new OperatorCountLP(constraint_collection, merge_lp_variables,
			lp_solver_type, cost_type);
		cout << "LP creation time: " << compile_timer << endl;
		heuristic = 0;
                cout << "Initialization time: " << init_timer << endl;
	}

	/*
	  Append num_samples states to samples, each obtained by a random walk
	  that starts in the initial state.

	  samples               -- output vector; sampled states are appended.
	  num_samples           -- number of states to generate.
	  average_operator_cost -- used to turn the heuristic value of the initial
	                           state into an estimated number of steps.
	  heuristic             -- evaluated on the initial state to size the walks
	                           and on every visited state to detect dead ends.

	  The length of each walk is the number of successes in n trials with
	  success probability p = 0.5. A walk stops early when no operator is
	  applicable and jumps back to the initial state when it reaches a state
	  that heuristic reports as a dead end.
	*/
	void PosthocOptimizationHeuristic::sample_states(vector<State> &samples, int num_samples, double average_operator_cost, Heuristic *heuristic)
	{
		const State &initial_state = g_initial_state();
		heuristic->evaluate(initial_state);
		assert(!heuristic->is_dead_end());

		int h = heuristic->get_heuristic();
		int n;
		if (h == 0) {
			n = 10;
		} else {
			// Convert heuristic value into an approximate number of actions
			// (does nothing on unit-cost problems).
			// average_operator_cost cannot equal 0, as in this case, all operators
			// must have costs of 0 and in this case the if-clause triggers.
			int solution_steps_estimate = int((h / average_operator_cost) + 0.5);
			n = 4 * solution_steps_estimate;
		}
		double p = 0.5;
		// The expected walk length is np = 2 * estimated number of solution steps.
		// (We multiply by 2 because the heuristic is underestimating.)

		int count_samples = 0;
		samples.reserve(num_samples);

		// Number of random numbers drawn; only reported at the end.
		int rng_count = 0;
		cout << "num_samples: " << num_samples << endl;
		cout << "n: " << n << endl;
		cout << "average operator cost: " << average_operator_cost << endl;

		for (; count_samples < num_samples; ++count_samples) {
			// calculate length of random walk according to a binomial distribution
			int length = 0;
			for (int j = 0; j < n; ++j) {
				double random = g_rng(); // [0..1)
				rng_count++;
				if (random < p)
					++length;
			}

			// random walk of the computed length, starting from the initial state
			State current_state(initial_state);
			for (int j = 0; j < length; ++j) {
				vector<const Operator *> applicable_ops;
				g_successor_generator->generate_applicable_ops(current_state, applicable_ops);
				// if there are no applicable operators --> do not walk further
				if (applicable_ops.empty()) {
					break;
				} else {
					int random = g_rng.next(applicable_ops.size()); // [0..applicable_ops.size())
					rng_count++;
					assert(applicable_ops[random]->is_applicable(current_state));
					// TODO for now, we only generate registered successors. This is a temporary state that
					// should not necessarily be registered in the global registry: see issue386.
					current_state = g_state_registry->get_successor_state(current_state, *applicable_ops[random]);
					// if current state is a dead end, then restart with initial state
					heuristic->evaluate(current_state); // TODO: only evaluate if dead end or not
					if (heuristic->is_dead_end())
						current_state = initial_state;
				}
			}
			// last state of the random walk is used as sample
			samples.push_back(current_state);
		}

		cout << "rng_count = " << rng_count << endl;

	}

	/*
	  Forward the reach_state notification to every constraint generator.

	  All generators are notified, even after one of them has already
	  reported a change. Returns true iff at least one generator returned true.
	*/
	bool PosthocOptimizationHeuristic::reach_state(const State &parent_state,
		const Operator &op, const State &state)
	{
		bool any_generator_dirty = false;
		const size_t num_generators = constraint_generators.size();
		for (size_t index = 0; index < num_generators; ++index) {
			// Call first, then combine, so that no generator is skipped.
			const bool generator_dirty =
				constraint_generators[index]->reach_state(parent_state, op, state);
			any_generator_dirty = any_generator_dirty || generator_dirty;
		}
		return any_generator_dirty;
	}

	/*
	  Compute the heuristic value of state by solving the LP.

	  Each constraint generator first updates the LP for the given state; if
	  any of them reports a dead end, DEAD_END is returned immediately without
	  solving. Otherwise the LP is solved and its heuristic value is returned,
	  or DEAD_END if the LP has no feasible solution.
	*/
	int PosthocOptimizationHeuristic::compute_heuristic(const State &state)
	{
		assert(lp);
		// An earlier call may have returned early and left temporary
		// constraints behind, so clear them before adding new ones.
		lp->remove_temporary_constraints();

		const size_t num_generators = constraint_generators.size();
		for (size_t index = 0; index < num_generators; ++index) {
			if (constraint_generators[index]->update_constraints(state, *lp)) {
				return DEAD_END;
			}
		}

		lp->solve();
		int h_value = DEAD_END;
		if (lp->has_feasible_solution()) {
			h_value = lp->get_heuristic_value();
		}
		lp->remove_temporary_constraints();
		return h_value;
	}

	/*
	  Option parser callback for the "pho" plugin.

	  Declares the options of the heuristic, parses them and creates the
	  heuristic. Returns a null pointer in help mode and in dry-run mode; the
	  check that "constraint_generators" is non-empty is skipped in help mode
	  only.
	*/
	static Heuristic *_parse(OptionParser &parser)
	{
		parser.add_list_option<ConstraintGenerator *>(
			"constraint_generators",
			"methods that generate constraints over LP variables "
			"representing the number of operator applications");
		parser.add_option<bool>(
			"merge_lp_variables",
			"merge operators in the same equivalence class into one variable",
			"false");
		parser.add_option<int>(
			"pruning_samples",
			"number of samples to use for pruning. 0 = no pruning (default)",
			"0");
		parser.add_option<int>(
			"prune_num_constraints",
			"number of constraints after pruning. 0 = no pruning (default)",
			"0");
		add_lp_solver_option_to_parser(parser);
		Heuristic::add_options_to_parser(parser);

		Options opts = parser.parse();

		Heuristic *created = 0;
		if (!parser.help_mode()) {
			opts.verify_list_non_empty<ConstraintGenerator *>("constraint_generators");
			if (!parser.dry_run()) {
				created = new PosthocOptimizationHeuristic(opts);
			}
		}
		return created;
	}

	// Registers _parse as the factory for the plugin name "pho".
	static Plugin<Heuristic> _plugin("pho", _parse);
}
