// -*- mode: C++; c-file-style: "stroustrup"; c-basic-offset: 4; -*-
////////////////////////////////////////////////////////////////////
//
// $Id: dtg.cpp 942 2016-05-27 12:58:52Z Martin Wehrle $
//
////////////////////////////////////////////////////////////////////

#include "causalgraph/domainTransitionGraph.h"
#include "causalgraph/dtg.h"
#include "causalgraph/cg_operator.h"
#include "causalgraph/cg_variable.h"
#include "causalgraph/scc.h"

#include "common/message.h"

#include <iostream>
#include <map>
#include <cassert>

using namespace std;

typedef cg::DomainTransitionGraph::Transition Transition;

namespace cg {
// Create the DTG for the causal-graph variable `cgvar`.
//
// Remembers `cgvar` and its level (as `var`) and creates one ValueNode per
// value index 0 .. (cgvar->upper - cgvar->lower). No transitions are added
// here; see construct_transitions().
// NOTE(review): presumably node value v stands for the domain value
// cgvar->lower + v -- confirm against the callers.
DTG::DTG(const CGVariable* cgvar) :
    cgvar(cgvar),
    var(cgvar->level) {
    // Highest zero-based value index; the domain bounds are inclusive, so
    // there are last_value + 1 nodes in total.
    const int32_t last_value = cgvar->upper - cgvar->lower;

    // nodes = possible values for the variable; allocate them all up front
    nodes.reserve(last_value + 1);

    int32_t value = 0;
    while (value <= last_value) {
        nodes.push_back(ValueNode(this, value));
        ++value;
    }
}


// Populate the outgoing transitions of every node of this DTG.
//
// For each node `origin`, the transitions reported by
// graph->get_vertices(origin) are grouped by (origin, target): the first
// transition seen for an arc creates a ValueTransition, and every transition
// (including the first) adds one ValueTransitionLabel holding the operator
// and its prevail conditions.
//
// Parameters:
//   graph               - source of the raw transitions (queried per origin).
//   g_transition_graphs - DTGs indexed by variable level; used to look up the
//                         DTG of each prevail variable.
//   operators           - operators indexed by the transition's `op` field.
//
// Side effects: appends to nodes[*].transitions and to local_to_global_child.
void DTG::construct_transitions(DomainTransitionGraph* graph, vector<DTG*>& g_transition_graphs, vector<CGOperator*>& operators) {
    // variable level -> index into local_to_global_child
    map<int32_t, int32_t> global_to_local_child;
    // (origin, target) -> index into nodes[origin].transitions
    map<pair<int32_t, int32_t>, int32_t> transition_index;
    // TODO: This transition index business is caused by the fact
    //       that transitions in the input are not grouped by target
    //       like they should be. Change this.

    for (uint32_t origin = 0; origin < nodes.size(); origin++) {
        // Bind by const reference: no copy if get_vertices() returns a
        // reference, and a returned temporary is kept alive for this scope.
        const vector<cg::DomainTransitionGraph::Transition>& outgoing = graph->get_vertices(origin);
        const size_t trans_count = outgoing.size();     // nr of outgoing transitions with source origin

        for (size_t i = 0; i < trans_count; i++) {
            const Transition& trans = outgoing[i];
            const int32_t target = trans.target;
            const int32_t operator_index = trans.op;

            // target is used as an index into nodes below
            assert(target >= 0 && static_cast<size_t>(target) < nodes.size());

            // Find the ValueTransition for this arc, creating it on first use
            // (single lookup instead of count() + operator[] + operator[]).
            auto& out_transitions = nodes[origin].transitions;
            const pair<int32_t, int32_t> arc(static_cast<int32_t>(origin), target);
            auto arc_slot = transition_index.find(arc);
            if (arc_slot == transition_index.end()) {
                arc_slot = transition_index.insert(
                    make_pair(arc, static_cast<int32_t>(out_transitions.size()))).first;
                out_transitions.push_back(ValueTransition(&nodes[target]));
            }
            // Taken after the possible push_back above, so it cannot dangle.
            ValueTransition& transition = out_transitions[arc_slot->second];

            vector<PrevailCondition> prevail;

            for (size_t k = 0; k < trans.condition.size(); k++) {
                const auto& condition = trans.condition[k];
                if (condition.first->level == -1) {
                    continue;   // conditions on variables with level -1 are skipped
                }
                const int32_t global_var = condition.first->level;
                const int32_t val = condition.second;
                assert(val != -1);

                // TODO: think more about this... do we need it ?? Cycles are already ignored in DomainTransitionGraph
                // Status: should be always true because of the above reason
                assert(global_var < var);

                if (global_var >= var) {    // [ignore cycles] (only reachable with NDEBUG)
                    continue;
                }

                // Assign local indices to prevail variables in first-seen order.
                auto child_slot = global_to_local_child.find(global_var);
                if (child_slot == global_to_local_child.end()) {
                    child_slot = global_to_local_child.insert(
                        make_pair(global_var, static_cast<int32_t>(local_to_global_child.size()))).first;
                    local_to_global_child.push_back(global_var);
                }
                const int32_t local_var = child_slot->second;

                assert(global_var >= 0 && static_cast<size_t>(global_var) < g_transition_graphs.size());
                DTG* prev_dtg = g_transition_graphs[global_var];
                prevail.push_back(PrevailCondition(prev_dtg, local_var, val));
            }

            assert(operator_index >= 0 && static_cast<size_t>(operator_index) < operators.size());
            CGOperator* the_operator = operators[operator_index];
            transition.labels.push_back(ValueTransitionLabel(the_operator, prevail));
        }
    }
}

// Report whether this DTG forms a single strongly connected component.
//
// Builds a plain adjacency list (node index -> values of the successor
// nodes) from the transitions and hands it to SCC; returns true iff SCC
// reports exactly one component.
// NOTE(review): for a DTG without nodes the result depends on what SCC
// returns for an empty graph -- confirm if that case can occur.
bool DTG::is_strongly_connected() {
    vector<vector<int32_t>> unweighted_graph(nodes.size());
    for (size_t i = 0; i < nodes.size(); i++) {
        // Read the transitions in place; copying a ValueNode would deep-copy
        // all of its transitions, labels and prevail conditions.
        const auto& out_transitions = nodes[i].transitions;
        vector<int32_t>& succs = unweighted_graph[i];
        succs.reserve(out_transitions.size());
        for (size_t j = 0; j < out_transitions.size(); j++) {
            succs.push_back(out_transitions[j].target->value);
        }
    }

    // Evaluated in one full expression so the temporary SCC object outlives
    // the use of its result, whether get_result() returns by value or by
    // reference.
    return SCC(unweighted_graph).get_result().size() == 1;
}

// Check whether all guards (prevail conditions) of this DTG only refer to
// variables from some_variables.
//
// A prevail variable counts as contained when some entry of some_variables
// has a `level` equal to the variable's global index (looked up through
// local_to_global_child). Returns false as soon as one prevail variable is
// not contained; returns true otherwise, in particular when the DTG has no
// prevail conditions at all.
bool DTG::uses_only_these_vars(vector<CGVariable*>& some_variables) {
    // All containers are traversed by reference: the nested loops only read,
    // so copying transitions, labels or prevail vectors is pure overhead.
    for (const auto& node : nodes) {
        for (const auto& trans : node.transitions) {
            for (const auto& label : trans.labels) {
                for (const auto& condition : label.prevail) {
                    const int32_t local_var = condition.local_var;
                    assert(local_var >= 0 && static_cast<size_t>(local_var) < local_to_global_child.size());
                    const int32_t global_var = local_to_global_child[local_var];

                    bool found = false;
                    for (size_t r = 0; r < some_variables.size(); r++) {
                        if (some_variables[r]->level == global_var) {
                            found = true;
                            break;
                        }
                    }
                    if (!found) {
                        return false;
                    }
                }
            }
        }
    }

    return true;
}

// Write a human-readable dump of this DTG to `o` and return `o`.
//
// Prints the variable and node count, then for every node its value and
// number of outgoing transitions, for every transition its number of labels,
// and for every label its prevail conditions as (local_var, value) pairs.
ostream& DTG::display(ostream& o) const {
    o << "DTG for variable " << var << ", number of nodes =  " << nodes.size() << endl;

    uint32_t node_no = 0;
    for (const auto& node : nodes) {
        o << "  " << node_no << ". node: value = " << node.value
          << ", number of outgoing transitions " << node.transitions.size() << endl;

        uint32_t transition_no = 0;
        for (const auto& transition : node.transitions) {
            o << "    " << transition_no << ". transition: number of labels: "
              << transition.labels.size() << endl;

            for (const auto& label : transition.labels) {
                o << "  new label starts" << endl;
                for (const auto& condition : label.prevail) {
                    o << "     condition (var,val) = " << condition.local_var
                      << "," << condition.value << endl;
                }
            }
            ++transition_no;
        }
        ++node_no;
    }
    return o;
}
}
