#include "states.h"

#include "evaluatables.h"
#include "search_engine.h"

#include "utils/string_utils.h"
#include "utils/math_utils.h"

#include <sstream>

using namespace std;

string State::toCompactString() const {
    // Layout: every deterministic fluent value followed by a space, then the
    // separator "| ", then every probabilistic fluent value followed by a
    // space.
    ostringstream out;
    for (double const& detValue : deterministicStateFluents) {
        out << detValue << ' ';
    }
    out << "| ";
    for (double const& probValue : probabilisticStateFluents) {
        out << probValue << ' ';
    }
    return out.str();
}

string State::toCompactStringForFile() const {
    // Same as toCompactString() but without the "| " separator: all
    // deterministic values, then all probabilistic values, each followed by
    // a single space.
    ostringstream out;
    for (double const& detValue : deterministicStateFluents) {
        out << detValue << ' ';
    }
    for (double const& probValue : probabilisticStateFluents) {
        out << probValue << ' ';
    }
    return out.str();
}

string State::toString() const {
    // One "<cpf name>: <value>" line per deterministic fluent, a blank line,
    // one such line per probabilistic fluent, and finally the remaining steps
    // and the state's hash key.
    stringstream out;
    for (size_t det = 0; det < State::numberOfDeterministicStateFluents;
         ++det) {
        out << SearchEngine::deterministicCPFs[det]->name << ": "
            << deterministicStateFluents[det] << "\n";
    }
    out << "\n";
    for (size_t prob = 0; prob < State::numberOfProbabilisticStateFluents;
         ++prob) {
        out << SearchEngine::probabilisticCPFs[prob]->name << ": "
            << probabilisticStateFluents[prob] << "\n";
    }
    out << "Remaining Steps: " << remSteps << "\n";
    out << "StateHashKey: " << hashKey << "\n";
    return out.str();
}

string PDState::toCompactString() const {
    // Deterministic fluent values (each followed by a space), immediately
    // followed by the textual form of each probabilistic fluent's
    // distribution as produced by DiscretePD::toString().
    ostringstream out;
    for (double const& detValue : deterministicStateFluents) {
        out << detValue << ' ';
    }
    for (DiscretePD const& distribution : probabilisticStateFluentsAsPD) {
        out << distribution.toString();
    }
    return out.str();
}

namespace {
// Recursively enumerates the joint outcomes of the probabilistic fluents of
// `state`, beginning at fluent `index`.
//
// Fluents whose distribution is deterministic are fixed in place to their
// single value without branching. For the first non-deterministic fluent,
// every value of its distribution is tried in turn and the recursion
// continues with the next fluent, multiplying `prob` by that value's
// probability. Once all fluents are assigned, (state, prob) is appended to
// `result`.
//
// `state` is taken by value on purpose: each branch mutates its own copy.
void expandPDState(PDState state, double prob, int index,
                   std::vector<std::pair<PDState, double>>& result) {
    // Skip over (and fix) all fluents that need no branching.
    for (; index < state.numberOfProbabilisticStateFluents &&
           state.probabilisticStateFluentAsPD(index).isDeterministic();
         ++index) {
        state.probabilisticStateFluent(index) =
            state.probabilisticStateFluentAsPD(index).values[0];
    }

    if (index != state.numberOfProbabilisticStateFluents) {
        // Branch on every possible value of the current fluent.
        DiscretePD const& distribution =
            state.probabilisticStateFluentAsPD(index);
        for (size_t valueIdx = 0; valueIdx < distribution.values.size();
             ++valueIdx) {
            state.probabilisticStateFluent(index) =
                distribution.values[valueIdx];
            expandPDState(state,
                          prob * distribution.probabilities[valueIdx],
                          index + 1, result);
        }
        return;
    }

    // Every probabilistic fluent has a concrete value: record the outcome.
    result.push_back(make_pair(state, prob));
}
}  // namespace

std::vector<std::pair<PDState, double>> PDState::expand() const {
    // Enumerate every joint outcome of this state's probabilistic fluents,
    // each paired with its probability, starting from certainty (1.0) at the
    // first fluent.
    std::vector<std::pair<PDState, double>> outcomes;
    expandPDState(*this, 1.0, 0, outcomes);
    return outcomes;
}


// Despite its name, this does not return a count. It returns true iff the
// number of joint outcomes of the probabilistic fluents (the product of the
// sizes of their distributions) is at most `threshold`. It returns false as
// soon as the running product exceeds `threshold` or
// MathUtils::multiplyWithOverflowCheck reports a failure.
bool PDState::getNumberOfPDSuccessors(int threshold) const {
    long successors = 1;
    for (int fluent = 0; fluent < numberOfProbabilisticStateFluents;
         ++fluent) {
        bool const multiplied = MathUtils::multiplyWithOverflowCheck(
            successors, probabilisticStateFluentAsPD(fluent).size());
        if (!multiplied) {
            return false;
        }
        if (successors > threshold) {
            return false;
        }
    }
    return true;
}

//samples PDState outcomes if there are to many successors to expand 
std::vector<std::pair<PDState, double>> PDState::sampleOutcomes(int numberOfSamples) const{
  
    std::set<pair<PDState, double>, PDStateComparisonInequality> outcomes;
    double accumulatedProb = 0.0;
    for(int index = 0; index < numberOfSamples; index++){
        PDState copy(*this);
        std::pair<double,double> outcome;
        double prob = 1.0;
        for(int pdDigit = 0; pdDigit < numberOfProbabilisticStateFluents; pdDigit++){
            outcome = copy.sample(pdDigit);
            prob *= outcome.second; 

            //safe last entry and check for duplicates
            if((pdDigit == numberOfProbabilisticStateFluents-1)){
                if(get<1>(outcomes.insert(make_pair(copy, prob)))){
                    accumulatedProb += prob;
                }
            }
        }
    }
    
    if(accumulatedProb == 0){
        cout << "accumulatedProb = 0" << endl;
    }

    //normalize probabilities
    std::vector<std::pair<PDState, double>> result;

    for(std::pair<PDState, double> const& singlePair : outcomes){
        result.push_back(make_pair(singlePair.first, singlePair.second/accumulatedProb));
    }
    return result;
}


string KleeneState::toString() const {
    // One line per CPF: "<name>: { v1 v2 ... }", listing every value
    // currently stored for that CPF in this Kleene state.
    stringstream out;
    for (unsigned int cpf = 0; cpf < KleeneState::stateSize; ++cpf) {
        out << SearchEngine::allCPFs[cpf]->name << ": { ";
        for (double value : state[cpf]) {
            out << value << ' ';
        }
        out << "}\n";
    }
    return out.str();
}

vector<string> ActionState::getScheduledActionFluentNames() const {
    // Collects the names of the action fluents this action sets.
    // - Non-FDR (boolean) fluents are listed by their own name when their
    //   value index is non-zero.
    // - FDR fluents are listed by the name of the selected value, unless
    //   that value is "none-of-those".
    vector<string> names;
    for (size_t fluent = 0; fluent < state.size(); ++fluent) {
        int const chosen = state[fluent];
        auto const& actionFluent = SearchEngine::actionFluents[fluent];
        if (actionFluent->isFDR) {
            auto const& valueName = actionFluent->values[chosen];
            if (valueName != "none-of-those") {
                names.push_back(valueName);
            }
        } else if (chosen) {
            names.push_back(actionFluent->name);
        }
    }
    return names;
}

string ActionState::toCompactString() const {
    // "noop()" if no action fluent is scheduled. Otherwise each scheduled
    // fluent name is followed by a single space.
    vector<string> const names = getScheduledActionFluentNames();
    if (names.empty()) {
        return "noop()";
    }
    string joined;
    for (size_t pos = 0; pos < names.size(); ++pos) {
        joined += names[pos];
        joined += ' ';
    }
    return joined;
}

string ActionState::toString() const {
    // Compact action description, its index, and the names of the relevant
    // preconditions (one per line).
    stringstream out;
    out << toCompactString() << ": \n";
    out << "Index : " << index << "\n";
    out << "Relevant preconditions: \n";
    for (DeterministicEvaluatable const* precondition : actionPreconditions) {
        out << precondition->name << "\n";
    }
    return out.str();
}
