#include "stratified_sampling.h"

#include <iostream>
#include <string>
#include <utility>

#include "../evaluation_context.h"
#include "../heuristic.h"
#include "../option_parser.h"

using namespace std;

// A state's "type" (stratum key): see create_type for the layout.
using Type = std::vector<int>;

StratifiedSampler::StratifiedSampler(const std::shared_ptr<AbstractTask> task, const std::shared_ptr<HeuristicWrapper> hw)
: task(task),
  task_proxy(*task),
  successor_generator(task),
  hw(hw)
{
	State tmp(task_proxy.convert_global_state(g_state_registry->get_initial_state()));
	set_root(tmp);
	//	sample();
}

void StratifiedSampler::estimate_max_distance(int dist){
	max_distance = dist;
}

void StratifiedSampler::set_root(const State &g){
	State tmp(g);
	vector<StateWeightPair> next_layer_tmp;
	next_layer_tmp.push_back(StateWeightPair(tmp, 1.0));
	A.push_back(next_layer_tmp);
}

void StratifiedSampler::add_goal(const State &g){
	State tmp(g);
	vector<StateWeightPair> next_layer_tmp;
	next_layer_tmp.push_back(StateWeightPair(tmp, 1.0));
	Frontier tmp_b;
	tmp_b.push_back(next_layer_tmp);
	B.push_back(tmp_b);
}

void StratifiedSampler::expand_all_children(const StateWeightPair &p, vector<StateWeightPair> &next_layer_tmp){
	State g(p.first);
	//next_layer_tmp.clear();
	vector<OperatorProxy> ops;
	ops.clear();
	successor_generator.generate_applicable_ops(g, ops);
	if(ops.empty()){
		cout << "EMPTY" << endl;
	}else{
		for(const OperatorProxy go : ops){
			State tmp = g.get_successor(go);
			vector<int> t = create_type(tmp);
			//if(has_same_type(tmp, next_layer)){
			pair<bool, size_t> result = same_type_in(t, next_layer_tmp);
			if(result.first){
				//cout << "There are states with the same type" << endl;
				double w_prime = next_layer_tmp[result.second].second + p.second;

				uniform_real_distribution<double> distribution(0.0,1.0);

				if(distribution(generator) <= p.second / w_prime){
					next_layer_tmp[result.second] = StateWeightPair(tmp, w_prime);
				}
			}else{
				next_layer_tmp.push_back(StateWeightPair(tmp, p.second));
			}
		}
	}

}

void StratifiedSampler::expand_frontier(Frontier &F){
	vector<StateWeightPair> next_layer_tmp;
	for(StateWeightPair parent : F.back()){
		expand_all_children(parent, next_layer_tmp);
	}
	F.push_back(next_layer_tmp);
}


void StratifiedSampler::reset_frontier(Frontier &F, int i){
	for(size_t k = i; k < F.size(); k++){
//		delete &F[k];
		F.erase(F.begin() + k);
	}
}

int StratifiedSampler::compute_k(int m){
	double gam = 0.5;
	return max((int)(gam * m), 1);
}

int StratifiedSampler::advance_only_A(int n, int m) {
	for(Frontier tmp_b : B){
		if (check_overlap_in_layers(n, m, tmp_b)) {
			int k = compute_k(m);
			int ran = 1;
			for (ran = 1; ran <= k; ran++) {
				expand_frontier(A);
				if (!check_overlap_in_layers(n + ran, m - ran, tmp_b)) {
					break;
				}
			}
			if (ran == k) {
//				cout << "SOLUTION COST: " << n + m << endl;
				return n + m;
			} else {
				reset_frontier(A, n);
			}
		}
	}
	return -1;
}

int StratifiedSampler::sample(){
	int n = 0;
	int m = 0;
	for(int i = 0; i < max_distance; i++){		//TODO: change stopping criteria
		expand_frontier(A);
		n++;
		if(advance_only_A(n, m) >= 0){
			return n+m;
		}

		for(Frontier &tmp_b : B){
			expand_frontier(tmp_b);
		}
		m++;
		if(advance_only_A(n, m) >= 0){
			return n+m;
		}

	}
	return -1;
}

bool StratifiedSampler::check_overlap_in_layers(int start_a, int start_b, Frontier &tmp_b){
	size_t a_check = start_a;
	size_t b_check = start_b;
	if(a_check < A.size() && b_check < tmp_b.size()){
		for(auto node_in_A : A[start_a]){
			Type type_a = create_type(node_in_A.first);
			bool res = same_type_in(type_a, tmp_b[start_b]).first;
			if(res){
				return true;
			}
		}
		return false;
	}else{
		return false;
	}
}

pair<bool, size_t> StratifiedSampler::same_type_in(const vector<int> &type, const vector<StateWeightPair> &layer){
	for(size_t i = 0; i < layer.size(); i++){
		StateWeightPair p = layer[i];
		vector<int> compare = create_type(p.first);
		bool equal = true;
		for(size_t k = 0; k < type.size(); k++){
			if(type[k] != compare[k]){
				equal = false;
				break;
			}
		}
		if(equal){
			return make_pair(true, i);
		}
	}
	return make_pair(false, 0);
}

vector<int> StratifiedSampler::create_type(const State &g){
	int max_heuristic_value = 100;
	vector<int> type(max_heuristic_value + 1);		//fist position is the heuristic value of the current state, the following positions are the number of childrens with a heuristic value in [0..max_heuristc_value]
	vector<OperatorProxy> ops;
	successor_generator.generate_applicable_ops(g, ops);

	type[0] = hw->get_heuristic_value(g);

	for(OperatorProxy go : ops){
		State child = g.get_successor(go);
		int h = hw->get_heuristic_value(child);
		if(h <= max_heuristic_value){
			type[h+1]++;
		}
	}
	return type;
}

