// -*- mode: C++; c-file-style: "stroustrup"; c-basic-offset: 4; -*-
////////////////////////////////////////////////////////////////////
//
// $Id: context.cpp 933 2016-05-27 10:17:48Z Martin Wehrle $
//
////////////////////////////////////////////////////////////////////

#include "context.h"
#include "successor_generator.h"
#include "normalizer.h"
#include "open_list.h"
#include "closed_list.h"
#include "heuristic.h"
#include "transition.h"
#include "interference.h"

#include "common/message.h"
#include "common/option.h"

#include "system/process.h"
#include "system/state.h"
#include "system/target.h"
#include "system/task.h"

using namespace std;

// NOTE: for the icb search it is important that two states are considered
// equal only if they are reached via the same transition. (This keeps the
// number of context switches minimal; otherwise minimality is not guaranteed.)

// Closed list used by the context-bounded search. In addition to
// State::discEqual, two states only count as equal when they were reached
// via the same transition (same `reachedby`).
class ContextClosedList : public ClosedList {
public:
    bool equal(const State* s1, const State* s2) const {
        if (s1->reachedby != s2->reachedby) {
            return false;
        }
        return s1->discEqual(s2);
    }
};

////////////////////////////////////////////////////////////////////

// Abstract notion of the "current context" of the context-bounded search.
// set() fixes the context from a state; contains() asks whether a successor
// state lies inside that context.
class Context {
public:
    virtual ~Context() {}

    // Make the context the one determined by `state`.
    virtual void set(const State* state) = 0;

    // Whether `state` belongs to the current context.
    virtual bool contains(const State* state) const = 0;
};

////////////////////////////////////////////////////////////////////

// Degenerate context that never contains any state: every successor is
// reported as lying outside the current context.
class SingleContext : public Context {
public:
    virtual void set(const State*) {
        // Nothing to remember.
    }

    virtual bool contains(const State*) const {
        return false;
    }
};

////////////////////////////////////////////////////////////////////

// Degenerate context that contains every state: no successor is ever
// reported as lying outside the current context.
class AllContext : public Context {
public:
    virtual void set(const State*) {
        // Nothing to remember.
    }

    virtual bool contains(const State*) const {
        return true;
    }
};

////////////////////////////////////////////////////////////////////

class ProcessContext : public Context {
protected:
    const Process* context1;
    const Process* context2;

    bool contains(const Edge* edge) const {
        assert(edge);
        assert(context1);
        if (context2) {
            return edge->getProcess()->id == context1->id || edge->getProcess()->id == context2->id;
        } else {
            return edge->getProcess()->id == context1->id;
        }
    }
public:
    ProcessContext() : context1(NULL), context2(NULL) {}

    void set(const State* state) {
        assert(state && state->reachedby);
        context1 = state->reachedby->edge1->getProcess();
        if (state->reachedby->edge2) {
            context2 = state->reachedby->edge2->getProcess();
        } else {
            context2 = NULL;
        }
    }

    bool contains(const State* state) const {
        assert(state->reachedby);
        assert(context1);
        if (!contains(state->reachedby->edge1)) {
            return false;
        }
        if (state->reachedby->edge2 && !contains(state->reachedby->edge2)) {
            return false;
        }
        return true;
    }
};

////////////////////////////////////////////////////////////////////

class InterferenceContext : public Context {
private:
    InterferenceFilter ip;
    const Transition* context;
    uint32_t distance;
public:
    InterferenceContext(const Task* task, const Options* opts) :
        ip(task, opts),
        context(NULL),
        distance(opts->ce_dist)
    {}

    virtual void set(const State* state) {
        context = state->reachedby;
    }

    virtual bool contains(const State* state) const {
        assert(context);
        return ip.interferes(context, state->reachedby, distance);
    }
};

////////////////////////////////////////////////////////////////////

// Context-switching wrapper around a search engine SE.
//
// Successors that lie inside the current Context are inserted into SE's own
// open list as usual. All other successors are diverted into the list
// pointed to by `next` (deferred to a later context switch). The Context
// implementation is selected from opts->icb.
//
// CSE owns the Context object it allocates and releases it on destruction.
template<class SE>
class CSE : public SE {
private:
    OpenList** next;    // deferred (out-of-context) open list; not owned
    Context* context;   // owned; NULL only if opts->icb was not recognised

    // Non-copyable: `context` is owned and deleted in the destructor, so a
    // shallow copy would lead to a double delete.
    CSE(const CSE&);
    CSE& operator=(const CSE&);
public:
    CSE(const Task* task, const Options* opts, OpenList** next) :
        SE(task, opts),
        next(next),
        context(NULL) {
        switch (opts->icb) {
        case ICB_PROC:
            context = new ProcessContext;
            break;
        case ICB_INTERFERENCE:
            context = new InterferenceContext(task, opts);
            break;
        case ICB_SINGLE:
            context = new SingleContext;
            break;
        case ICB_ALL:
            context = new AllContext;
            break;
        default:
            // NOTE(review): if error() does not terminate, `context` stays
            // NULL; the asserts below catch that in debug builds.
            error() << "No such option" << endl;
        }
    }

    ~CSE() {
        delete context;
    }

    // Seed SE's open list with `initial` and fix the context from it.
    void prepareExploration(const SuccessorGenerator&, State* initial) {
        assert(context);
        assert(SE::open->empty());
        SE::open->insert(initial, 0, SE::heur->eval(initial));
        context->set(initial);
    }

    // Route `succ` into SE's open list if it is in context. Otherwise,
    // temporarily swap the deferred list in as SE::open so SE::insert files
    // it there, then swap back.
    void insert(State* succ) {
        assert(context);
        if (context->contains(succ)) {
            SE::insert(succ);
        } else {
            swap(*next, SE::open);
            SE::insert(succ);
            swap(*next, SE::open);
        }
    }
};

////////////////////////////////////////////////////////////////////

// Initialise the outer engine and build the inner context-switching engine
// (`cse`). The engine variant wrapped by CSE is chosen from opts->ce and
// opts->search. The inner engine shares this engine's closed list,
// heuristic and normalizer, but gets its own clone of the open list.
void ContextSearchEngine::init(OpenList* o, ClosedList* c, Heuristic* h, Normalizer* n) {
    SearchEngine::init(o, c, h, n);

    // NOTE(review): this replaces whatever closed list SearchEngine::init
    // installed from `c` — confirm ownership of `c` with callers.
    closed = new ContextClosedList;

    // Deferred list receiving out-of-context successors.
    // NOTE(review): function-local static — created once and shared by every
    // ContextSearchEngine instance.
    static OpenList* dummy = open->clone();
    next = &dummy;

    const bool untimed = (opts->search == UT_SEARCH);
    if (opts->ce) {
        if (untimed) {
            cse = new CSE<UTIPSearchEngine>(task, opts, next);
        } else {
            cse = new CSE<IPSearchEngine>(task, opts, next);
        }
    } else {
        if (untimed) {
            cse = new CSE<UTSearchEngine>(task, opts, next);
        } else {
            cse = new CSE<SearchEngine>(task, opts, next);
        }
    }
    cse->init(open->clone(), closed, heur, norm);
}

// Explore the initial state on the outer engine.
// If it already satisfies the target, record SAT and the goal state and
// return false. Otherwise normalise it, intern its zone, store it in the
// closed list, append its successors to `initial_succs`, and return true.
bool ContextSearchEngine::generateInitialSuccessors(State* initial, vector<State*>& initial_succs) {
    SuccessorGenerator generator(task, opts, indep_size);
    prepareExploration(generator, initial);

    State* root = open->get();
    ++exploredStates;
    printProgress(std::cerr);

    const bool reachedTarget = task->target->isSatBy(root);
    if (!reachedTarget) {
        norm->normalize(root);
        root->zone().intern();
        closed->insert(root);
        generator.generateSuccessors(root, initial_succs);
        generatedStates += initial_succs.size();
        return true;
    }

    stat = SAT;
    goal = root;
    return false;
}

// Context-bounded exploration from `initial`.
// Each state in the outer open list is handed to the inner engine `cse`.
// Out-of-context successors accumulate in *next. Once the outer open list is
// drained, the lists are swapped and a context switch is counted. The search
// ends with SAT as soon as the inner engine reports it, or with UNSAT when
// no deferred states remain.
void ContextSearchEngine::explore(State* initial) {
    vector<State*> first_layer;
    if (!generateInitialSuccessors(initial, first_layer)) {
        return;
    }
    for (vector<State*>::iterator it = first_layer.begin(); it != first_layer.end(); ++it) {
        insert(*it);
    }
    // this->open and this->closed now hold the immediate successors of
    // the initial state.

    uint32_t switches = 0;
    while (true) {
        // Drain the current context.
        while (!open->empty()) {
            State* start = open->get();
            assert(cse->open->empty());
            cse->explore(start);

            // Mirror the inner engine's statistics and verdict.
            exploredStates = cse->exploredStates;
            generatedStates = cse->generatedStates;
            stat = cse->stat;
            goal = cse->goal;

            if (stat == SAT) {
                debug() << "Context switches: " << switches << endl;
                return;
            }
            assert(cse->open->empty());
        }

        // Nothing deferred: the whole bounded state space has been explored.
        if ((*next)->empty()) {
            stat = UNSAT;
            debug() << "Context switches: " << switches << endl;
            return;
        }

        ++switches;
        assert(open->empty());
        assert(cse->open->empty());
        swap(open, *next);
    }
}
