#include "bootstrap.h"

#include <bits/shared_ptr_base.h>
#include <cassert>
#include <csignal>
#include <iostream>
#include <iterator>
#include <memory>
#include <sstream>
#include <string>
#include <utility>

#include "stratified_sampling.h"

#include "../abstract_task.h"
#include "../evaluation_context.h"
#include "../ff_heuristic.h"
#include "../globals.h"
#include "../option_parser.h"
#include "../rng.h"
#include "../task_proxy.h"
#include "../task_tools.h"
#include "../timer.h"

#define MIN_RW_LENGTH 5

using namespace std;

/*
 * Construct the bootstrap learner and immediately run the whole procedure.
 *
 * Options read from opts: "number_of_rw", "rw_length", "ann_topology",
 * "initial_heuristic", "method", "eta" and "epochs".
 * The constructor body builds the neural network, executes run_bootstrap()
 * (which finishes with train_learner()) and prints the elapsed time to cout.
 */
Bootstrap::Bootstrap(
		const std::shared_ptr<AbstractTask> task,
		const Options opts)
: task(task),
  task_proxy(*task),
  successor_generator(task),
  // Copy of the global initial state data; solve_instances() writes it back
  // to g_initial_state_data after its searches.
  set_back_to_initial_state(g_initial_state_data),
  number_of_random_walks(opts.get<int>("number_of_rw")),
  random_walk_length(opts.get<int>("rw_length")),
  ann_topology(opts.get<string>("ann_topology")),
  initial_heuristic_name(opts.get<string>("initial_heuristic")),
  method(opts.get<string>("method")),
  eta(opts.get<double>("eta")),
  epochs(opts.get<int>("epochs")),
  generator(r_device()),
  // NOTE(review): members are initialized in class declaration order (header
  // not visible here); hw reads initial_heuristic_name, so that member has to
  // be declared before hw - confirm.
  hw(new HeuristicWrapper(initial_heuristic_name))

{
	verify_no_axioms();
	// Started before any work so the printed time covers network setup and
	// the complete bootstrap run.
	Timer timer;
	initialize();
	initialize_neural_network();

	run_bootstrap();

	cout << "Bootstrap initialization time: " << timer << endl;
}

/*
 * Run a StratifiedSampler rooted at state s and return its sample() result.
 * Ten states produced by create_goal() are registered as goals beforehand.
 */
int Bootstrap::call_biss(State s) {
	const int goals_to_register = 10;

	StratifiedSampler sampler(task, hw);
	// The root is the given state (not create_start()).
	sampler.set_root(s);

	int registered = 0;
	while(registered < goals_to_register){
		sampler.add_goal(create_goal());
		++registered;
	}
	return sampler.sample();
}

/*
 * Fill the learning set according to `method`, then train the network on it.
 *
 * Values of `method` handled here:
 *  - "search":        create instances and solve them with the search engine
 *                     (cout is redirected while solving); the states on the
 *                     solution paths are added with their distance to the goal.
 *  - "walk":          random walks starting in goal states; the visited states
 *                     are added with the operator cost accumulated so far.
 *  - "sample":        random walks from goal states and from the start state;
 *                     the visited states are added with the value of the
 *                     wrapped heuristic hw.
 *  - "sample_random": number_of_random_walks * random_walk_length randomly
 *                     generated states, added with the value of hw.
 *  - "estimate":      same as "walk", afterwards every value is replaced by
 *                     the result of call_biss() if that result is not negative.
 * Every other value only prints "Specify a method".
 * train_learner() is called in all cases.
 */
void Bootstrap::run_bootstrap() {
	if(method == "search"){
		create_instances();
		suppress_output();
		solve_instances();
		restore_output();
	}else if(method == "walk"){
		initialize_training_set();
	}else if(method == "sample"){
		// First all walks from goal states, then all walks from the start state.
		for (int walk = 0; walk < number_of_random_walks; ++walk) {
			sample_state_space(create_goal());
		}
		for (int walk = 0; walk < number_of_random_walks; ++walk) {
			sample_state_space(create_start());
		}
	}else if(method == "sample_random"){
		const int samples = number_of_random_walks * random_walk_length;
		for (int drawn = 0; drawn < samples; ++drawn) {
			sample_state_space_random();
		}
	}else if(method == "estimate"){
		initialize_training_set();
		for(size_t idx = 0; idx < learning_set.size(); ++idx){
			const int estimate = call_biss(learning_set[idx].first);
			// A negative result leaves the value from the random walk in place.
			if(estimate < 0){
				continue;
			}
			learning_set.replace_value(idx, estimate);
		}
	}else{
		cout << "Specify a method" << endl;
	}

	train_learner();
}

/*
 * Initialization hook called by the constructor before the neural network is
 * built. It currently does nothing.
 */
void Bootstrap::initialize(){
}

void Bootstrap::sample_state_space(const State &start_state){
	StateValuePairVector state_vector;
	perform_random_walk(start_state, generate_rw_length(), state_vector);
	for(StateValuePair p : state_vector){
		int h = hw->get_heuristic_value(p.first);
		learning_set.add_entry(p.first, h);
	}
}

void Bootstrap::sample_state_space_random(){
	State s = create_random_state();
	int h = hw->get_heuristic_value(s);
	learning_set.add_entry(s, h);
}


void Bootstrap::initialize_neural_network(){
	NetTopology topology;
	topology.push_back(FeatureExtractor::get_num_features(task_proxy.get_initial_state()));

	istringstream iss(ann_topology);
	stringstream ss;
	double d;
	for(string s;getline(iss, s, '|');){
		ss.clear();
		ss << s;
		ss >> d;
		topology.push_back(d);
	}

	topology.push_back(1);
	nn = new NeuralNet(topology);
}

void Bootstrap::initialize_training_set() {
	StateValuePairVector state_vector;
	for (int i = 0; i < number_of_random_walks; i++) {
		const State goal = create_goal();
		perform_random_walk(goal, generate_rw_length(), state_vector);
		for (StateValuePair p : state_vector) {
			learning_set.add_entry(p);
		}
	}
}

void Bootstrap::create_instances(){
	StateValuePairVector state_vector;
	for (int i = 0; i < number_of_random_walks; i++) {
		const State goal = create_goal();
		perform_random_walk(goal, generate_rw_length(), state_vector);
		learning_set.add_entry(state_vector.back());
	}
//	initialize_training_set();

	int number_of_random_states = 50;
	for(int i = 0; i < number_of_random_states; i++){
		instances.push_back(create_random_state());
	}
	for(size_t i = 0; i < learning_set.size(); i++){
		instances.push_back(learning_set[i].first);
	}
	learning_set.clear();

	cout << "Size of instances: " << instances.size() << endl;
}

/*
 * Create a state that satisfies all goal conditions of the task.
 * Every variable first gets a value drawn uniformly from
 * [0, g_variable_domain[var] - 1]; the variables mentioned in the goal are
 * then overwritten with their goal values.
 */
State Bootstrap::create_goal(){
	vector<int> values;
	const size_t num_variables = g_variable_domain.size();
	for(size_t var = 0; var < num_variables; ++var){
		uniform_int_distribution<int> pick_value(0, g_variable_domain[var] - 1);
		values.push_back(pick_value(generator));
	}

	// Fix the goal variables to the values the goal demands.
	GoalsProxy goals = task_proxy.get_goals();
	for(size_t goal = 0; goal < goals.size(); ++goal){
		const int var_id = goals[goal].get_variable().get_id();
		values[var_id] = goals[goal].get_value();
	}

	return State(*task, move(values));
}

/*
 * Return the global initial state (g_initial_state()) converted to a State
 * of this task via task_proxy.
 */
State Bootstrap::create_start(){
	const auto &global_initial = g_initial_state();
	return task_proxy.convert_global_state(global_initial);
}

/*
 * Create a state in which every variable holds a value drawn uniformly from
 * [0, g_variable_domain[var] - 1].
 */
State Bootstrap::create_random_state(){
	vector<int> values;
	const size_t num_variables = g_variable_domain.size();
	for(size_t var = 0; var < num_variables; ++var){
		uniform_int_distribution<int> pick_value(0, g_variable_domain[var] - 1);
		values.push_back(pick_value(generator));
	}
	return State(*task, move(values));
}

/*
 * Evaluate the search that was just run for the instance at instance_iterator.
 *
 * No solution: the instance stays in `instances` and the iterator advances.
 * Solution found: the plan is replayed from task_proxy.get_initial_state();
 * every state reached after a plan step is added to the learning set together
 * with the plan cost that is still left from there. The start state itself is
 * not added. The instance is then erased and the iterator is set to the
 * element following it.
 */
void Bootstrap::extract_states_on_path(vector<State>::iterator &instance_iterator) {
	if (!engine->found_solution()) {
		++instance_iterator;
		return;
	}

	SearchEngine::Plan plan = engine->get_plan();
	State walker(task_proxy.get_initial_state());
	int remaining_cost = calculate_plan_cost(plan);

	vector<OperatorProxy> candidates;
	for (auto plan_step : plan) {
		candidates.clear();
		successor_generator.generate_applicable_ops(walker, candidates);
		// Find the applicable operator that belongs to this plan step.
		for (OperatorProxy candidate : candidates) {
			if (!(candidate.get_global_operator() == plan_step)) {
				continue;
			}
			remaining_cost -= candidate.get_cost();
			assert(remaining_cost >= 0);
			walker = walker.get_successor(candidate);
			learning_set.add_entry(make_pair(walker, remaining_cost));
			break;
		}
	}
	// After the last step the whole plan cost has to be used up.
	assert(remaining_cost == 0);
	instance_iterator = instances.erase(instance_iterator);
}

/*
 * Run a search (time limit t_max, see set_search_engine()) from every state in
 * `instances` and hand the result to extract_states_on_path(), which adds the
 * states on a found solution path to the learning set and removes solved
 * instances from the list. Unsolved instances remain in `instances`.
 *
 * Afterwards the registry used for the bootstrap searches is deleted, the
 * saved initial state data is restored and the state registries are swapped
 * back.
 */
void Bootstrap::solve_instances(){
	/* Without instances set_search_engine() is never called, so no bootstrap
	 * registry has been swapped in. Running the cleanup below anyway would
	 * delete the registry that is currently installed and swap although
	 * nothing was swapped before.
	 */
	if(instances.empty()){
		return;
	}

	vector<State>::iterator instance_iterator = instances.begin();
	while(instance_iterator != instances.end()){
		set_search_engine(*instance_iterator);
		engine->search();
		// Advances the iterator (no solution) or erases the instance (solved).
		extract_states_on_path(instance_iterator);
	}

	// The bootstrap procedure is over here, swap the registries back to use the initial g_state_registry
	delete g_state_registry;
	g_initial_state_data = set_back_to_initial_state;
	g_swap_state_registries();
}

/*
 * Hand the learning set to the neural network and train it.
 * For every entry the features of the state are extracted and passed to the
 * network as doubles together with the stored value. Afterwards the training
 * set is dumped, Neuron::eta and the number of epochs are set from this
 * object's eta / epochs, and training is started.
 */
void Bootstrap::train_learner(){
	vector<int> features;
	size_t entry = 0;
	while(entry < learning_set.size()){
		features.clear();
		State state = learning_set[entry].first;
		FeatureExtractor::extract_features(state, features);

		// The network takes its inputs as doubles.
		vector<double> inputs;
		inputs.assign(features.begin(), features.end());
		nn->add_training_set_entry(inputs, learning_set[entry].second);
		++entry;
	}

	nn->dump_training_set();
	Neuron::eta = eta;
	nn->set_epochs(epochs);
	nn->train_network();
}

/*
 * Call this every time the time limit is adjusted.
 */
void Bootstrap::set_search_engine(const State &start){
	if(first_run){
		/* Just swap the state registries for the first time, to save the initial state of the g_state_registry.
		 * g_state_registry is used in the final search, and we do not to use a messed up registry for this.
		 * If the initial g_state_registry is already swapped, we just delete the actual g_state_registry in initialize a new one.
		 */
		g_swap_state_registries();
		first_run = false;
	}
	delete g_state_registry;
	g_state_registry = new StateRegistry;

	g_initial_state_data = FeatureExtractor::extract_features(start);
	delete engine;
	stringstream parse_option;
	parse_option << "astar(";
	parse_option << initial_heuristic_name;
	parse_option << ", max_time=" << t_max;
	parse_option << ")";
	OptionParser parser(parse_option.str(), false);
	engine = parser.start_parsing<SearchEngine *>();
}

void Bootstrap::perform_random_walk(const State &in_state, int rw_length, StateValuePairVector &rw_set){
	rw_set.clear();
	State current_state(in_state);
	vector<OperatorProxy> applicable_ops;


	int rw_cost = 0;
	for (int j = 0; j < rw_length; ++j) {
		rw_set.push_back(make_pair(current_state, rw_cost));
		applicable_ops.clear();

		successor_generator.generate_applicable_ops(current_state,
				applicable_ops);
		// If there are no applicable operators, do not walk further.raise
		if (applicable_ops.empty()) {
			break;
		} else {
			const OperatorProxy &random_op = *g_rng.choose(applicable_ops);
			assert(is_applicable(random_op, current_state));
			current_state = current_state.get_successor(random_op);
			rw_cost += random_op.get_cost();
		}
	}
}

/*
 * Convenience overload: performs the random walk into a new vector and
 * returns that vector instead of filling one supplied by the caller.
 */
StateValuePairVector Bootstrap::perform_random_walk(const State &in_state, int rw_length){
	StateValuePairVector rw_set;
	perform_random_walk(in_state, rw_length, rw_set);
	return rw_set;
}

int Bootstrap::generate_rw_length(){
	uniform_int_distribution<int> dist(MIN_RW_LENGTH, MIN_RW_LENGTH + random_walk_length);
	return dist(generator);
}

int Bootstrap::get_heuristic_from_neural_net(const vector<int> &features){
	vector<double> d_features(features.begin(), features.end());
	nn->feed_forward(d_features);
	vector<double> results;
	nn->get_results(results);
	return results[0];
}

/*
 * Silence cout by redirecting it to /dev/null (through the stream sup_out).
 * NOTE(review): the buffer cout used before is not stored here;
 * restore_output() relies on cout_backup, which is presumably set elsewhere -
 * confirm.
 */
void Bootstrap::suppress_output(){
	sup_out.open("/dev/null");
	auto *null_buffer = sup_out.rdbuf();
	cout.rdbuf(null_buffer);
}

/*
 * Let cout write to the stream buffer stored in cout_backup again.
 * NOTE(review): cout_backup is not assigned in the code visible here;
 * presumably it is initialized elsewhere with cout's original buffer - confirm.
 */
void Bootstrap::restore_output(){
	cout.rdbuf(cout_backup);
}
