// -*- mode: C++; c-file-style: "stroustrup"; c-basic-offset: 4; -*-
////////////////////////////////////////////////////////////////////
//
// $Id: interference.cpp 935 2016-05-27 10:23:55Z Martin Wehrle $
//
////////////////////////////////////////////////////////////////////

#include "interference.h"
#include "transition.h"

#include "common/option.h"

#include "system/assignment.h"
#include "system/edge.h"
#include "system/effect.h"
#include "system/guard.h"
#include "system/location.h"
#include "system/process.h"
#include "system/state.h"
#include "system/system.h"
#include "system/target.h"
#include "system/task.h"

#include <algorithm>
#include <cassert>
#include <climits>
#include <ext/hash_set>
#include <vector>

using namespace std;

// Builds a filter for `task` configured by `opts`: looks up the state
// field "interference level" (which must exist) and precomputes the
// transitive interference matrix.
InterferenceFilter::InterferenceFilter(const Task* task, const Options* opts)
    : task(task)
    , opts(opts)
    , interference_level(task->system->builder->getField("interference level"))
{
    // getField() yields -1 when the field is unknown; labelling would then
    // write to a non-existent slot.
    assert(interference_level != -1);
    generate_closure();
}

// Fills `inter` with the length of the shortest interference chain
// between every pair of transitions and sets `max_closure` to the longest
// finite such chain.
void InterferenceFilter::generate_closure() {
    // NOTE: inter[i][j] == n iff trans_i needs at least n transitions
    // in a chain to interfere with trans_j; INT_MAX means trans_i never
    // interferes with trans_j, not even transitively.

    TransitionBuilder& builder = TransitionBuilder::builder(task->system);
    const uint32_t size = builder.getNrTransitions();
    inter.assign(size, vector<int32_t>(size, INT_MAX));

    // Direct interference: distance 0 to itself, 1 to every transition
    // that interfere() reports.
    for (uint32_t i = 0; i < size; i++) {
        // FIXME: perhaps 0 can be replaced with i+1 (interfere() as
        // defined in this file is symmetric)
        for (uint32_t j = 0; j < size; j++) {
            if (i == j) {
                inter[i][j] = 0;
            } else if (interfere(builder.getAllTransitions()[i], builder.getAllTransitions()[j])) {
                inter[i][j] = 1;
            }
        }
    }

    // Floyd-Warshall: all pairs shortest path. Finite distances are below
    // `size`, so the sum below cannot overflow for any matrix that fits in
    // memory.
    for (uint32_t k = 0; k < size; k++) {
        for (uint32_t i = 0; i < size; i++) {
            // inter[i][k] does not change during the j-loop (inter[k][k] == 0),
            // so it can be read once.
            const int32_t to_k = inter[i][k];
            if (to_k == INT_MAX) {
                continue;
            }
            for (uint32_t j = 0; j < size; j++) {
                if (inter[k][j] != INT_MAX) {
                    inter[i][j] = std::min(inter[i][j], to_k + inter[k][j]);
                }
            }
        }
    }

    // Longest finite chain in the FINAL matrix. Tracking this inside the
    // relaxation loop would also record intermediate distances that a
    // later k still shortens, overestimating max_closure.
    max_closure = 0;
    for (uint32_t i = 0; i < size; i++) {
        for (uint32_t j = 0; j < size; j++) {
            const int32_t dist = inter[i][j];
            if (dist != INT_MAX && max_closure < dist) {
                max_closure = dist;
            }
        }
    }
}

bool InterferenceFilter::innocent(const Transition* trans) const {
    static vector<int32_t> is_innocent(TransitionBuilder::builder(task->system).getNrTransitions(), -1);
    if (is_innocent[trans->uid] != -1) {
        return is_innocent[trans->uid] == 1;
    }

    if (!innocentLocations(trans->edge1) ||
        (trans->edge2 && !innocentLocations(trans->edge2)) ||
        !innocentIntegers(trans->edge1) ||
        (trans->edge2 && !innocentIntegers(trans->edge2))) {
        is_innocent[trans->uid] = 0;
    } else {
        is_innocent[trans->uid] = 1;
    }
    return is_innocent[trans->uid] == 1;
}

// Returns false iff `edge` ends in a location that one of the target's
// location constraints refers to (compared by system-wide location id).
bool InterferenceFilter::innocentLocations(const Edge* edge) const {
    const uint32_t nr_constraints = task->target->getLocationConstraints().size();
    for (uint32_t c = 0; c != nr_constraints; ++c) {
        const bool reaches_target_location =
            task->target->getLocationConstraints()[c]->loc->idInSystem == edge->dst->idInSystem;
        if (reaches_target_location) {
            return false;
        }
    }
    return true;
}

bool InterferenceFilter::innocentIntegers(const Edge* edge) const {
    for (uint32_t i = 0; i < edge->effect->intassigns.size(); i++) {
        const IntAssignment* ass = edge->effect->intassigns[i];
        for (uint32_t j = 0; j < task->target->getIntConstraints().size(); j++) {
            if (ass->lhs->id == task->target->getIntConstraints()[j]->lhs->id) {
                return false;
            }
        }
    }
    return true;
}

// Returns true iff `value` satisfies "value <comp> cons->rhs" for the
// comparison operator stored in `cons`. Unknown operators trip an assert
// (and yield false when asserts are disabled).
static inline bool is_consistent_with(int32_t value, const IntConstraint* cons) {
    bool satisfied = false;
    switch (cons->comp) {
    case Constraint::LT:  satisfied = (value <  cons->rhs); break;
    case Constraint::LE:  satisfied = (value <= cons->rhs); break;
    case Constraint::EQ:  satisfied = (value == cons->rhs); break;
    case Constraint::GE:  satisfied = (value >= cons->rhs); break;
    case Constraint::GT:  satisfied = (value >  cons->rhs); break;
    case Constraint::NEQ: satisfied = (value != cons->rhs); break;
    default:
        assert(false);
        break;
    }
    return satisfied;
}

// independet of whether edge2 is applicable in successor returns true
// iff there is an updated variable v in edge1 and a constraint v comp
// c in edge2 such that this constraint is not satisfied
static inline bool disables(const Edge* edge1, const Edge* edge2, const State* state) {
    // locations
    if (edge1->dst->proc->id == edge2->dst->proc->id) {
        if (edge1->dst->idInSystem != edge2->src->idInSystem) {
            return true;
        }
    }

    // // synchronization
    // if ((edge1->getType() == Edge::BANG && edge2->getType() == Edge::QUE) ||
    //  (edge2->getType() == Edge::BANG && edge1->getType() == Edge::QUE)) {

    //  if (edge1->getAction()->id == edge2->getAction()->id)
    //      return true;
    // }

    // integer variables
    for (uint32_t i = 0; i < edge1->effect->intassigns.size(); i++) {
        IntAssignment* ia = edge1->effect->intassigns[i];
        const ValueAssignment* va = dynamic_cast<ValueAssignment*>(ia);

        int32_t rhs_value = 0;
        uint32_t lhs_id = 0;
        if (va) {
            rhs_value = va->rhs;
            lhs_id = va->lhs->id;
        } else {
            const VarAssignment* va = dynamic_cast<VarAssignment*>(ia);
            assert(va);
            lhs_id = va->lhs->id;
            rhs_value = state->var(va->rhs->id);
        }

        for (uint32_t j = 0; j < edge2->guard->intconstraints.size(); j++) {
            const IntConstraint* ic = edge2->guard->intconstraints[j];
            if (ic->lhs->id == lhs_id && !is_consistent_with(rhs_value, ic)) {
                return true;
            }
        }
    }

    // clocks are not handled, yet
    // for (uint32_t i = 0; i < edge1->effect->resets.size(); i++) {
    //  for (uint32_t j = 0; j < edge2->guard->clockconstraints.size(); j++) {

    //      if (edge2->guard->clockconstraints[j]->comp == Constraint::GT && edge2->guard->clockconstraints[j]->rhs == 0)
    //          if (edge1->effect->resets[i]->lhs->id == edge2->guard->clockconstraints[j]->lhs->id)
    //              return true;

    //  }
    // }

    return false;
}

// Assumption: trans1 is applied in state.
// Returns: true iff after applying trans1 to state, trans2 is not applicable
static inline bool disables(const Transition* trans1, const Transition* trans2, const State* state) {
    if (disables(trans1->edge1, trans2->edge1, state))
        return true;
    if (trans2->edge2 && disables(trans1->edge1, trans2->edge2, state))
        return true;
    if (trans1->edge2 && disables(trans1->edge2, trans2->edge1, state))
        return true;
    if (trans1->edge2 && trans2->edge2 && disables(trans1->edge2, trans2->edge2, state))
        return true;

    return false;
}

// Returns true iff, in `state`, the process owning `edge` is in the edge's
// source location (compared by process-local id) and the edge's guard is
// satisfied by `state`.
static inline bool is_enabled_in(const Edge* edge, const State* state) {
    const uint32_t current_location = state->proc(edge->getProcess()->id);
    const bool at_source = current_location == edge->src->idInProcess;
    return at_source && edge->guard->isSatBy(state);
}

bool InterferenceFilter::all_relevant_enabled(const State* state, const vector<State*>& succs) const {
    // enabled = succs(state) --> enabled_transitions
    const TransitionBuilder& builder = TransitionBuilder::builder();

    __gnu_cxx::hash_set<uint32_t> enabled;
    for (uint32_t i = 0; i < succs.size(); i++) {
        enabled.insert(succs[i]->reachedby->uid);
    }

    const vector<int32_t>& inter_with_t = inter[state->reachedby->uid];
    for (uint32_t i = 0; i < inter_with_t.size(); i++) {
        const Transition* ti = builder.getTransition(i);
        if (inter_with_t[i] <= 1 &&
            enabled.find(ti->uid) == enabled.end() &&
            !disables(state->reachedby, ti, state)) {
            return false;
        }
    }
    return true;
}

static inline bool enables_or_disables(const Edge* edge1, const Edge* edge2) {
    for (uint32_t i = 0; i < edge1->effect->intassigns.size(); i++) {
        const Integer* var = edge1->effect->intassigns[i]->lhs;
        for (uint32_t j = 0; j < edge2->guard->intconstraints.size(); j++) {
            if (var->id == edge2->guard->intconstraints[j]->lhs->id) {
                return true;
            }
        }

        for (uint32_t j = 0; j < edge2->effect->intassigns.size(); j++) {
            VarAssignment* va = dynamic_cast<VarAssignment*>(edge2->effect->intassigns[j]);
            if (va && var->id == va->rhs->id)
                return true;
        }

        // TODO: think about clocks more carefully
        // for (uint32_t i = 0; i < edge1->effect->resets.size(); i++)
        //     for (uint32_t j = 0; j < edge2->guard->clockconstraints.size(); j++)
        //      if (edge1->effect->resets[i]->lhs->id == edge2->guard->clockconstraints[j]->lhs->id)
        //          return true;
    }

    return false;
}

static inline bool writes_the_same_variable(const Edge* edge1, const Edge* edge2) {
    if (edge1->dst->proc->id == edge2->dst->proc->id) {
        return true;
    }

    for (uint32_t i = 0; i < edge1->effect->intassigns.size(); i++) {
        for (uint32_t j = 0; j < edge2->effect->intassigns.size(); j++) {
            if (edge1->effect->intassigns[i]->lhs->id == edge2->effect->intassigns[j]->lhs->id) {
                return true;
            }
        }
    }
    // Clocks can be ignored: If both x and y are set to zero, these
    // operations can be applied in any order, leading to the same
    // discrete state in all cases; only the zones can differ (x<=y
    // vs. y <= x). As we do not support difference constraints,
    // turning the order cannot lead to a different semantics.

    return false;
}

// Two edges interfere iff they belong to the same process, one writes an
// integer variable the other reads, or both write the same variable.
bool InterferenceFilter::interfere(const Edge* edge1, const Edge* edge2) const {
    return edge1->dst->proc->id == edge2->dst->proc->id
        || enables_or_disables(edge1, edge2)
        || enables_or_disables(edge2, edge1)
        || writes_the_same_variable(edge1, edge2);
}

bool InterferenceFilter::interfere(const Transition* trans1, const Transition* trans2) const {
    if (interfere(trans1->edge1, trans2->edge1))
        return true;
    if (trans1->edge2 && interfere(trans1->edge2, trans2->edge1))
        return true;
    if (trans2->edge2 && interfere(trans1->edge1, trans2->edge2))
        return true;
    if (trans1->edge2 && trans2->edge2 && interfere(trans1->edge2, trans2->edge2))
        return true;
    return false;
}

// Writes the interference level into the `interference_level` slot of the
// successors of a state that was reached via an innocent transition.
//  - ce_closure: successors whose transition never interferes with the
//    parent's transition (distance INT_MAX) get INT_MAX.
//  - otherwise (if ce_with_enabled is off, or all directly relevant
//    transitions are enabled/disabled): each still unlabelled (-1) successor
//    at distance d >= 2 gets min(d - 1, max_closure), i.e. the largest
//    IP in [1, max_closure] with d > IP.
void InterferenceFilter::label(vector<State*>& succs) const {
    if (succs.empty()) {
        return;
    }
    const State* parent = succs.front()->predecessor;
    if (!parent || !parent->reachedby || !innocent(parent->reachedby)) {
        return;
    }
    const vector<int32_t>& inter_with_parent = inter[parent->reachedby->uid];

    if (opts->ce_closure) {
        for (uint32_t i = 0; i < succs.size(); i++) {
            if (inter_with_parent[succs[i]->reachedby->uid] == INT_MAX) {
                succs[i]->extra(interference_level) = INT_MAX;
            }
        }
    } else if (!opts->ce_with_enabled || all_relevant_enabled(parent, succs)) {
        // Direct O(|succs|) form of scanning IP = max_closure .. 1 and taking
        // the first IP below the distance.
        const int32_t closure = static_cast<int32_t>(max_closure);
        for (uint32_t i = 0; i < succs.size(); i++) {
            const int32_t distance = inter_with_parent[succs[i]->reachedby->uid];
            const int32_t level = std::min<int32_t>(distance - 1, closure);
            if (level > 0 && succs[i]->extra(interference_level) == -1) {
                succs[i]->extra(interference_level) = level;
            }
        }
    }
}

bool InterferenceFilter::interferes(const Transition* t1, const Transition* t2, int32_t bound) const {
    return inter[t1->uid][t2->uid] <= bound;
}
