#include "sdd_search.h"

#include "../search_engines/search_common.h"

#include "../evaluation_context.h"
#include "../globals.h"
#include "../heuristic.h"
#include "../option_parser.h"
#include "../plugin.h"
#include "../pruning_method.h"
#include "../successor_generator.h"

#include "../open_lists/open_list_factory.h"

#include <cassert>
#include <cstdlib>
#include <memory>
#include <set>
#include <map>
#include <vector>
#include <string.h>
#include <algorithm>
 

extern "C" {
	#include "sddapi.h"
}

using namespace std;

namespace sdd_search {
SddSearch::SddSearch(const Options &opts)
    : SearchEngine(opts),
      _orderType(opts.get<std::string>("ordertype")),
      _orderTree(opts.get<std::string>("order")),
      _verbose(opts.get<bool>("verbose")) {
}

void SddSearch::initialize() {
	
	_domains = g_variable_domain;
	int vars = _domains.size();
	
	_varCount = 0;
	for (int i = 0; i < vars; i++) {
		_varCount += _domains[i];
	}
	_varCount *= 2; //primed variables
	
	if (_orderTree.compare("right") == 0) {
		_manager = sdd_manager_new(sdd_vtree_new(_varCount, "right"));
		cout << "Building right-linear vtree. (effectively BDD)" << endl;
	} else if (_orderTree.compare("left") == 0) {
		_manager = sdd_manager_new(sdd_vtree_new(_varCount, "left"));
		cout << "Building left-linear vtree." << endl;
	} else if (_orderTree.compare("vertical") == 0) {
		_manager = sdd_manager_new(sdd_vtree_new(_varCount, "vertical"));
		cout << "Building vertical vtree." << endl;
	} else {
		_manager = sdd_manager_new(sdd_vtree_new(_varCount, "balanced"));
		cout << "Building balanced vtree." << endl;
	}
	sdd_manager_auto_gc_and_minimize_on(_manager);
	
	//set variable order and renaming arrays
	setOrder(_orderType);
	
	// now make a list of T_o for all actions
	
	int size_op = g_operators.size();
	for (int i = 0; i < size_op; i++) {
		
		map<int,int> pre;
		std::vector<GlobalCondition> gpre = g_operators[i].get_pre();
		int size_pre = gpre.size();
		for (int j = 0; j < size_pre; j++) {
			pre[gpre[j].var] = gpre[j].val;
		}
		
		map<int,int> eff;
		std::vector<GlobalEffect> geff = g_operators[i].get_eff();
		int size_eff = geff.size();
		for (int j = 0; j < size_eff; j++) {
			eff[geff[j].var] = geff[j].val;
		}
		
		int cost = g_operators[i].get_cost();
		
		SddNode* transition = generateTransition(pre, eff);
		_operators[cost].push_back(transition);
		_operatorCorrespondence[transition] = &(g_operators[i]);
		
	}
}
	
/**
  * Sets the variable order in _order and _orderPrime that will henceforth be adhered to.
  * Do not change mid-program, as that will lead to misinterpretation of SDDs generated prior.
  * Only call upon initialization!
  * also sets renaming arrays _renameToPrime and _renameFromPrime
  */
void SddSearch::setOrder(const std::string& mode) { //currently only one order supported
	
	if (_verbose) cout << "start set order" << endl;
	
	int vars = _domains.size();
	_order.resize(vars);
	_orderPrime.resize(vars);
	
	
	int i = 1;
	
	if (mode.compare("primedlast") == 0) {// primed last A_F , A_T , B_F , B_T , A_F' , A_T' , B_F' , B_T'    (correspondence is a nightmare with this one, worst one so far ^^)
		cout << "Using primed last order" << endl;
		
		for (int v = 0; v < vars; v++) {
			_order[v].resize(_domains[v]);
			_orderPrime[v].resize(_domains[v]);
			for (int d = 0; d < _domains[v]; d++) {
				if (_verbose) cout << v << "_" << d << " -> " << i << endl;
				_order[v][d] = i++;
			}
		}
		for (int v = 0; v < vars; v++) {
			_order[v].resize(_domains[v]);
			_orderPrime[v].resize(_domains[v]);
			for (int d = 0; d < _domains[v]; d++) {
				if (_verbose) cout << v << "_" << d << "' -> " << i << endl;
				_orderPrime[v][d] = i++;
			}
		}
		
	} else if (mode.compare("densedomains") == 0) {// dense domains A_F , A_T , A_F' , A_T' , B_F , B_T , B_F' , B_T'
		cout << "Using dense domains order" << endl;
		
		for (int v = 0; v < vars; v++) {
			_order[v].resize(_domains[v]);
			_orderPrime[v].resize(_domains[v]);
			for (int d = 0; d < _domains[v]; d++) {
				if (_verbose) cout << v << "_" << d << " -> " << i << endl;
				_order[v][d] = i++;
			}
			for (int d = 0; d < _domains[v]; d++) {
				if (_verbose) cout << v << "_" << d << "' -> " << i << endl;
				_orderPrime[v][d] = i++;
			}
		}
		
	} else {
		if (mode.compare("simple") == 0) {// simple order A_F , A_F' , A_T , A_T', B_F , B_F' , B_T , B_T'
			cout << "Using simple variable order" << endl;
		} else {
			cout << "Order not recognized; using simple variable order" << endl;
		}

		for (int v = 0; v < vars; v++) {
			_order[v].resize(_domains[v]);
			_orderPrime[v].resize(_domains[v]);
			for (int d = 0; d < _domains[v]; d++) {
				if (_verbose) cout << v << "_" << d << " -> " << i << endl;
				_order[v][d] = i++;
				if (_verbose) cout << v << "_" << d << "' -> " << i << endl;
				_orderPrime[v][d] = i++;
			}
		}
		
	}
	
	
	if (_verbose) cout << "renaming arrays" << endl;
	//make renaming arrays
	_renameToPrime = new SddLiteral[_varCount+1];
	_renameFromPrime = new SddLiteral[_varCount+1];
	
	for (i = 0; i <= _varCount; i++) {
		_renameToPrime[i] = 0;
		_renameFromPrime[i] = 0;
	}
	for (int v = 0; v < vars; v++) {
		for (int d = 0; d < _domains[v]; d++) {
			_renameToPrime[_order[v][d]] = _orderPrime[v][d];
			_renameFromPrime[_orderPrime[v][d]] = _order[v][d];
		}
	}
	
	if (_verbose) cout << "end set order" << endl;
}
	
/**
  * Builds the SDD of the transition relation T(x, x') for a single action,
  * using the one-hot encoding fixed by setOrder().
  *
  * pre: SAS+ variable index -> value index the variable must hold before the
  *      action (e.g. loc=C => 0->2). The required value's unprimed literal must
  *      be true; every other value of that variable must be false.
  * eff: SAS+ variable index -> value index the variable holds after the action.
  *      The same is constrained on the primed literals.
  * Every variable NOT in eff gets the frame constraint x_v_d <-> x_v_d' for
  * each of its values.
  *
  * The returned node carries one reference, owned by the caller.
  */
SddNode* SddSearch::generateTransition(map<int,int> &pre, map<int,int> &eff) {

	const int varsCount = _domains.size();

	if (_verbose) cout << ">>>>>  Begin Generating Transition:" << endl; //DEBUG

	SddNode* T = sdd_manager_true(_manager);
	// Replace T by (T AND factor), keeping exactly one reference on the new T.
	auto conjoinInto = [this, &T](SddNode* factor) {
		SddNode* previous = T;
		T = sdd_conjoin(previous, factor, _manager);
		sdd_ref(T, _manager);
		sdd_deref(previous, _manager);
	};

	if (_verbose) cout << "  preconditions" << endl; //DEBUG
	for (int v = 0; v < varsCount; ++v) {
		const auto required = pre.find(v);
		if (required == pre.end()) continue;
		for (int d = 0; d < _domains[v]; ++d) {
			if (_verbose) cout << "        Address of T =  " << T << endl; //DEBUG
			const bool mustHold = (required->second == d);
			if (_verbose) cout << "    Variable " << v << "_" << d << " (" << _order[v][d]
			                   << (mustHold ? ") must be true" : ") must be false") << endl; //DEBUG
			conjoinInto(sdd_manager_literal(mustHold ? _order[v][d] : -(_order[v][d]), _manager));
		}
	}

	if (_verbose) cout << "  effects" << endl; //DEBUG
	for (int v = 0; v < varsCount; ++v) {
		const auto target = eff.find(v);
		if (target == eff.end()) continue;
		for (int d = 0; d < _domains[v]; ++d) {
			const bool willHold = (target->second == d);
			if (_verbose) cout << "    Variable " << v << "_" << d << "' (" << _orderPrime[v][d]
			                   << (willHold ? ") will be true" : ") will be false") << endl; //DEBUG
			conjoinInto(sdd_manager_literal(willHold ? _orderPrime[v][d] : -(_orderPrime[v][d]), _manager));
		}
	}

	if (_verbose) cout << "  correspondence" << endl; //DEBUG
	for (int v = 0; v < varsCount; ++v) {
		if (eff.count(v) > 0) continue; // only untouched variables need frame axioms
		for (int d = 0; d < _domains[v]; ++d) {
			const auto x = _order[v][d];
			const auto xPrime = _orderPrime[v][d];
			if (_verbose) cout << "Variable " << v << "_" << d << " (" << x << ") will correspond with "
			                   << v << "_" << d << "' (" << xPrime << ")" << endl; //DEBUG

			SddNode* forward = sdd_disjoin(sdd_manager_literal(-x, _manager), sdd_manager_literal(xPrime, _manager), _manager); // not X or X'
			sdd_ref(forward, _manager);
			SddNode* backward = sdd_disjoin(sdd_manager_literal(-xPrime, _manager), sdd_manager_literal(x, _manager), _manager); // not X' or X
			sdd_ref(backward, _manager);
			SddNode* equivalence = sdd_conjoin(forward, backward, _manager); // X <-> X'
			sdd_ref(equivalence, _manager);
			sdd_deref(forward, _manager);
			sdd_deref(backward, _manager);

			conjoinInto(equivalence);
			sdd_deref(equivalence, _manager);

			if (_verbose) cout << "  done" << endl;
		}
	}
	if (_verbose) cout << "<<<<< Done generating Transition" << endl; //DEBUG

	return T;
}

/**
  * Reports search statistics. Nothing in this engine collects statistics
  * yet, so this deliberately prints nothing.
  */
void SddSearch::print_statistics() const {
	// Intentionally empty.
}

/**
  * Runs the entire symbolic uniform-cost search in one call.
  * Since very little of the provided functionality of FastDownward can be
  * used, this function computes the plan on its own.
  *
  * Open and closed lists are bucketed by path cost g. Each iteration:
  * 1. Takes the cheapest open layer.
  * 2. Removes states that were already expanded (duplicate elimination).
  * 3. Stores the remaining states in the closed list and checks them against the goal.
  * 4. Otherwise expands them with fillInSuccessors().
  *
  * Without step 2, planning tasks with cycles never run out of open states,
  * so unsolvable tasks would loop forever.
  *
  * Returns SOLVED (after storing the plan via set_plan) or FAILED. It never
  * returns IN_PROGRESS. The SDD manager is freed before returning on both paths.
  */
SearchStatus SddSearch::step() {
	cout << "Beginning symbolic algorithm" << endl;

	GlobalState ginit = g_initial_state();
	SddNode* tmp;

	int vars = _domains.size();

	SddNode* initialState = sdd_manager_true(_manager);
	for (int v = 0; v < vars; v++) {
		int val = ginit[v];
		for (int d = 0; d < _domains[v]; d++) {
			if (val == d) {
				if (_verbose) cout << "    Variable " << v << "_" << d << " (" << _order[v][d] << ") is initially true" << endl; //DEBUG
				initialState = sdd_conjoin(tmp = initialState, sdd_manager_literal(_order[v][d],_manager),_manager);
				sdd_ref(initialState,_manager);
				sdd_deref(tmp,_manager);
			} else {
				if (_verbose) cout << "    Variable " << v << "_" << d << " (" << _order[v][d] << ") is initially false" << endl; //DEBUG
				initialState = sdd_conjoin(tmp = initialState, sdd_manager_literal(-(_order[v][d]),_manager),_manager);
				sdd_ref(initialState,_manager);
				sdd_deref(tmp,_manager);
			}
		}
	}

	SddNode* goalCondition = sdd_manager_true(_manager);
	for (int v = 0; v < vars; v++) {
		int val = -1;
		int size_goal = g_goal.size();
		for (int j = 0; j < size_goal; j++) {
			if (g_goal[j].first == v) {
				val = g_goal[j].second;
				break;
			}
		}
		if (val >= 0) {
			if (_verbose) cout << "    Variable " << v << "_" << val << " (" << _order[v][val] << ") is part of goal" << endl; //DEBUG
			goalCondition = sdd_conjoin(tmp = goalCondition, sdd_manager_literal(_order[v][val],_manager),_manager);
			sdd_ref(goalCondition,_manager);
			sdd_deref(tmp,_manager);
		}
	}
	_open[0] = initialState;

	SddNode* reachedGoalStates = sdd_manager_false(_manager);
	// Union of all layers expanded so far. A state in here was already expanded
	// at an equal or lower cost and must not be expanded again.
	SddNode* expandedStates = sdd_manager_false(_manager);
	int gmin = 0;
	bool solved = false;

	while (!_open.empty()) {

		gmin = getGMin();

		// Duplicate elimination: keep only states that were not expanded before.
		SddNode* notExpanded = sdd_negate(expandedStates, _manager);
		sdd_ref(notExpanded, _manager);
		SddNode* freshStates = sdd_conjoin(_open[gmin], notExpanded, _manager);
		sdd_ref(freshStates, _manager);
		sdd_deref(notExpanded, _manager);
		sdd_deref(_open[gmin], _manager);
		if (sdd_node_is_false(freshStates)) {
			sdd_deref(freshStates, _manager);
			_open.erase(gmin);
			continue;
		}
		_open[gmin] = freshStates; // the open list keeps freshStates' reference

		expandedStates = sdd_disjoin(tmp = expandedStates, freshStates, _manager);
		sdd_ref(expandedStates, _manager);
		sdd_deref(tmp, _manager);

		//insert into closed list
		if (_closed.count(gmin) == 0) {
			_closed[gmin] = freshStates;
			sdd_ref(_closed[gmin],_manager);
		} else {
			_closed[gmin] = sdd_disjoin(tmp = _closed[gmin], freshStates, _manager);
			sdd_ref(_closed[gmin],_manager);
			sdd_deref(tmp,_manager);
		}

		sdd_deref(reachedGoalStates,_manager);
		reachedGoalStates = sdd_conjoin(freshStates, goalCondition,_manager);
		sdd_ref(reachedGoalStates,_manager);

		if (!sdd_node_is_false(reachedGoalStates)) {
			//Solution Found!
			printf("Solution found!\n");
			printf("Cost is %d\n", gmin);
			solved = true;
			break;
		}

		fillInSuccessors(gmin);
	}

	if (!solved) {
		printf("no solution found\n");
		// Release the manager on this path too (it used to leak here).
		sdd_manager_free(_manager);
		_manager = nullptr;
		return FAILED;
	}

	// reconstructPath takes over the reference held on reachedGoalStates.
	vector<const GlobalOperator*> plan = reconstructPath(reachedGoalStates, gmin);
	set_plan(plan);

	sdd_manager_free(_manager); //TODO find a more appropriate spot for this (destructor?)
	_manager = nullptr;

	return SOLVED;
}
	
	
/**
  * Returns the smallest cost g for which the open list has an entry.
  *
  * Scans the existing keys directly instead of probing g = 0, 1, 2, ...
  * The old probing did one lookup per integer between layers, which is very
  * slow with large action costs, and never terminated on an empty open list.
  * Precondition: _open is not empty (asserted).
  */
int SddSearch::getGMin() {

	if (_verbose) printf("Start looking for g\n"); //DEBUG

	assert(!_open.empty());
	int g = _open.begin()->first;
	for (const auto& layer : _open) {
		if (layer.first < g) g = layer.first;
	}

	if (_verbose) printf("return g = %d\n", g); //DEBUG
	return g;

}

	
/**
  * Will generate the image of a given layer g and all actions,
  * fill them into the respective bins in the open list,
  * and remove the layer g from the open list.
  */
void SddSearch::fillInSuccessors(int g) {

	if (_verbose) cout << "start fillInSuccessors(" << g << ")" << endl; //DEBUG
	
	//TODO: Some parameter checking. Is there even a layer g?
	SddNode* s = _open[g];
	
	SddNode* tmp;
	
	int vars = _domains.size();
	
	for(map<int,list<SddNode*>>::iterator it_m = _operators.begin(); it_m != _operators.end(); ++it_m) { // For every operator...
		int cost = it_m->first;
    	list<SddNode*> fixCostOperators = it_m->second;
		for(list<SddNode*>::iterator it_l = fixCostOperators.begin(); it_l != fixCostOperators.end(); ++it_l) {// ... of cost c ...
			SddNode* op = *it_l;
			sdd_ref(op,_manager);
			
			SddNode* newState = sdd_conjoin(s, op,_manager);
			sdd_ref(newState,_manager);
			sdd_deref(op,_manager);
			
			if (_verbose) cout << "Eliminate old states..." << endl; //DEBUG									// ... calculate the successors ...
			for (int v = 0; v < vars; v++) {
				for (int d = 0; d < _domains[v]; d++) {
					newState = sdd_exists(_order[v][d], tmp = newState,_manager);
					sdd_ref(newState,_manager);
					sdd_deref(tmp,_manager);
				}
			}
			if (_verbose) cout << "...done" << endl; //DEBUG
			
			if (_verbose) cout << "Reordering Variables..." << endl; //DEBUG
			newState = sdd_rename_variables(tmp = newState, _renameFromPrime, _manager);
			sdd_ref(newState,_manager);
			sdd_deref(tmp,_manager);
			if (_verbose) cout << "...done" << endl; //DEBUG
			
			if (_verbose) cout << "Filling in results... " << endl; //DEBUG										// ... and fill them into the open list
			if (!sdd_node_is_false(newState)) {
				if (_open.count(g+cost) <= 0) {
					if (_verbose) cout << "...into new field " << g+cost << endl; //DEBUG
					_open[g+cost] = newState;
				} else {
					if (_verbose) cout << "... into existing field " << g+cost << endl; //DEBUG
					_open[g+cost] = sdd_disjoin(tmp = _open[g+cost], newState,_manager);
					sdd_ref(_open[g+cost],_manager);
					sdd_deref(tmp,_manager);
				}
			}
			else if (_verbose) cout << "... discarding " << g+cost << endl; //DEBUG
		}
		
	}
	
	sdd_deref(_open[g],_manager);
	_open.erase(g);
	
}
	
	
	
	
/**
  * Reconstructs a plan backwards from the reached goal states using the closed list.
  *
  * Starting from the goal states at cost g, each step does the following:
  * 1. Primes the current state set.
  * 2. Searches for a transition of some cost c whose preimage intersects
  *    _closed[g - c].
  * 3. Records that operator, continues with the intersected predecessor set,
  *    and sets g to g - c.
  *
  * Stops when g reaches 0.
  *
  * Ownership: if g > 0, the caller's reference on reachedGoalStates is
  * consumed (dereferenced) in the first iteration.
  *
  * If no predecessor can be found (which should never happen), an error is
  * printed and reconstruction stops. The previous behavior was to spin
  * forever; the returned plan is then incomplete.
  *
  * NOTE(review): zero-cost operators are not supported; they can make the
  * reconstruction cycle within a single layer.
  */
std::vector<const GlobalOperator*> SddSearch::reconstructPath(SddNode* reachedGoalStates, int g) {

	if (_verbose) cout << "start path reconstruction" << endl; //DEBUG

	std::vector<const GlobalOperator*> plan;

	SddNode* s = reachedGoalStates;

	SddNode* tmp;

	const int vars = _domains.size();

	while (g > 0) {	// Until the start is reached ...

		if (_verbose) cout << "Reordering Variables..." << endl; //DEBUG	// ... prime all variables ...
		s = sdd_rename_variables(tmp = s, _renameToPrime, _manager);
		sdd_ref(s, _manager);
		sdd_deref(tmp, _manager);
		if (_verbose) cout << "...done" << endl; //DEBUG

		bool found = false;

		for (auto it_m = _operators.begin(); it_m != _operators.end() && !found; ++it_m) {	// ... and for every transition T ...
			if (_verbose) cout << "g = " << g << endl; //DEBUG
			const int cost = it_m->first;

			if (_closed.count(g-cost) == 0) continue;

			const list<SddNode*>& fixCostOperators = it_m->second; // no per-iteration copy
			for (SddNode* op : fixCostOperators) {	// ... of cost c ...

				SddNode* predecessors = sdd_conjoin(s, op, _manager);	// ... calculate the predecessors of T ...
				sdd_ref(predecessors, _manager);

				if (_verbose) cout << "Eliminate primed states..." << endl; //DEBUG
				for (int v = 0; v < vars; v++) {
					for (int d = 0; d < _domains[v]; d++) {
						predecessors = sdd_exists(_orderPrime[v][d], tmp = predecessors, _manager);
						sdd_ref(predecessors, _manager);
						sdd_deref(tmp, _manager);
					}
				}
				if (_verbose) cout << "...done" << endl; //DEBUG

				predecessors = sdd_conjoin(tmp = predecessors, _closed[g-cost], _manager);	// ... cut with the entries in the closed list.
				sdd_ref(predecessors, _manager);
				sdd_deref(tmp, _manager);

				if (!sdd_node_is_false(predecessors)) {	// If an overlap was found ...
					if (_verbose) cout << " accepting operator" << endl; //DEBUG

					plan.push_back(_operatorCorrespondence[op]);	// ... save the operator that made it happen ...

					sdd_deref(s, _manager);
					s = predecessors;	// ... continue with the predecessors (keeps their reference) ...
					g = g-cost;	// ... and cost reduced by c.

					found = true; // also terminates the outer loop
					break;
				}

				if (_verbose) cout << " discard operator" << endl; //DEBUG
				sdd_deref(predecessors, _manager); // release the rejected candidate
			}
		}
		if (!found) {
			//Loop over all operators of all costs finished without finding a single overlap of the predecessors with the closed list.
			cout << "!!!!!!!!!!!!!!!! Could not find applicable operator. This should never happen." << endl;
			break; // g would never change again: stop instead of looping forever
		}

	}

	if (_verbose) cout << "end path reconstruction" << endl; //DEBUG
	std::reverse(plan.begin(), plan.end());
	return plan;
}

/**
  * Option parser for the "sddsearch" plugin.
  *
  * ordertype: variable order passed to SddSearch::setOrder. Unknown values
  *            fall back to "simple".
  * order:     vtree shape used by SddSearch::initialize. Unknown values fall
  *            back to "balanced".
  * verbose:   enables the //DEBUG console output.
  * Returns nullptr on a dry run, as the plugin framework expects.
  */
static SearchEngine *_parse(OptionParser &parser) {
    parser.document_synopsis("Symbolic sdd search", "");
    // setOrder recognizes "primedlast" (not "primedfirst", as previously advertised).
    parser.add_option<std::string>("ordertype",
                            "simple, primedlast or densedomains", "simple");
    parser.add_option<std::string>("order",
                            "balanced, right, left or vertical", "balanced");
    parser.add_option<bool>("verbose",
                            "displays debug messages", "false");
    SearchEngine::add_options_to_parser(parser);
    Options opts = parser.parse();

    SddSearch *engine = nullptr;
    if (!parser.dry_run()) {
        engine = new SddSearch(opts);
    }

    return engine;
}

static Plugin<SearchEngine> _plugin("sddsearch", _parse);
}
	
	