// -*- mode: C++; c-file-style: "stroustrup"; c-basic-offset: 4; -*-
////////////////////////////////////////////////////////////////////
//
// $Id: constraint.h 941 2016-05-27 12:47:37Z Martin Wehrle $
//
////////////////////////////////////////////////////////////////////

#ifndef SYSTEM_CONSTRAINT_H
#define SYSTEM_CONSTRAINT_H

#include "expression.h"
#include "varset.h"

#include "common/flyweight.h"
#include "common/helper.h"

#include <iosfwd>
#include <set>
#include <vector>

class Clock;
class Integer;
class Constant;
class State;
class Location;

/**
 * Abstract base class of the atomic constraints declared in this header
 * (ClockConstraint, IntConstraint, LocationConstraint).
 *
 * Subclasses must implement isSatBy() and readVars(); writeVars() defaults
 * to the empty variable set.
 */
class Constraint : public Expression {
public:
    /// Comparison kind; enumerators are consecutive starting at LT = 0.
    enum comp_t {LT = 0, LE, EQ, GE, GT, NEQ};

    /// Comparison used by this constraint (set by the constructor).
    comp_t comp;
    /// Identifier passed to the constructor.
    int32_t id;

    /// Returns true if @p state satisfies this constraint.
    virtual bool isSatBy(const State* state) const = 0;

    /// Variables written by this constraint; by default none.
    virtual varset_t writeVars() const
    {
        varset_t none;
        return none;
    }

    /// Variables read by this constraint.
    virtual varset_t readVars() const = 0;

protected:
    /// Only subclasses construct constraints; @p type is forwarded to the
    /// Expression base (presumably - constructor is defined elsewhere).
    Constraint(comp_t comp, Expression::type_t type, int32_t id);
};

////////////////////////////////////////////////////////////////////

/**
 * Atomic constraint on a clock: `lhs comp rhs`.
 *
 * Constructors are private; ConstraintFactory (a friend) creates
 * instances through its create() overloads.
 */
class ClockConstraint : public Constraint {
    friend class ConstraintFactory;

public:
    /// Clock on the left-hand side.
    Clock* lhs;
    /// Integer bound on the right-hand side.
    int32_t rhs;

    virtual bool operator==(const ClockConstraint& cc) const;

    /// Returns true if @p state satisfies this clock constraint.
    virtual bool isSatBy(const State* state) const;

    /// Variables read by this constraint.
    virtual varset_t readVars() const;

    /// Applies this constraint to @p state. Conjunction<ClockConstraint>
    /// treats a false result as failure and stops; the exact meaning of
    /// false (e.g. "state became empty") is defined in the .cpp - TODO confirm.
    bool constrain(State* state) const;

    /// Prints this constraint to @p o.
    virtual std::ostream& display(std::ostream& o) const;

private:
    ClockConstraint(Clock* lhs, comp_t comp, int32_t rhs, int32_t id);
    /// Variant taking the bound as a Constant (presumably resolved to its
    /// integer value - verify in the implementation).
    ClockConstraint(Clock* lhs, comp_t comp, const Constant* rhs, int32_t id);
};

////////////////////////////////////////////////////////////////////

/**
 * Atomic constraint on an integer variable: `lhs comp rhs`.
 *
 * Constructors are private; ConstraintFactory (a friend) creates
 * instances through its create() overloads.
 */
class IntConstraint : public Constraint {
    friend class ConstraintFactory;

public:
    /// Integer variable on the left-hand side.
    Integer* lhs;
    /// Integer value on the right-hand side.
    int32_t rhs;

    /// Variables read by this constraint.
    virtual varset_t readVars() const;

    virtual bool operator==(const IntConstraint& ic) const;

    /// Returns true if @p state satisfies this integer constraint.
    virtual bool isSatBy(const State* state) const;

    /// Prints this constraint to @p o.
    virtual std::ostream& display(std::ostream& o) const;

private:
    IntConstraint(Integer* lhs, comp_t comp, int32_t rhs, int32_t id);
    /// Variant taking the right-hand side as a Constant (presumably resolved
    /// to its integer value - verify in the implementation).
    IntConstraint(Integer* lhs, comp_t comp, const Constant* rhs, int32_t id);
};

////////////////////////////////////////////////////////////////////

/**
 * Constraint that refers to a single Location (presumably "the system is
 * in @c loc" - confirm against isSatBy() in the implementation).
 *
 * The constructor is private; ConstraintFactory (a friend) creates
 * instances through create(Location*).
 */
class LocationConstraint : public Constraint {
    friend class ConstraintFactory;

public:
    /// Location this constraint refers to.
    Location* loc;

    /// Variables read by this constraint.
    virtual varset_t readVars() const;

    virtual bool operator==(const LocationConstraint& lc) const;

    /// Returns true if @p state satisfies this location constraint.
    virtual bool isSatBy(const State* state) const;

    /// Prints this constraint to @p o.
    virtual std::ostream& display(std::ostream& o) const;

private:
    LocationConstraint(Location* loc, int32_t id);
};

////////////////////////////////////////////////////////////////////

/**
 * Singleton factory for IntConstraint, ClockConstraint and
 * LocationConstraint objects.
 *
 * Each constraint kind has its own Flyweight member (ints, clocks, locs);
 * the totalNr*() accessors report the size of each one.
 * Obtain the instance with getFactory().
 */
class ConstraintFactory {
private:
    Flyweight<IntConstraint> ints;
    Flyweight<ClockConstraint> clocks;
    Flyweight<LocationConstraint> locs;

    ConstraintFactory() {}

    // A copy of the singleton would carry its own, independent Flyweight
    // pools, so constraints created through it would not be shared with the
    // real factory. Declared private and intentionally left undefined
    // (pre-C++11 idiom, matching the rest of this header) so any attempt to
    // copy fails at compile/link time.
    ConstraintFactory(const ConstraintFactory&);
    ConstraintFactory& operator=(const ConstraintFactory&);
public:
    /// Returns the single factory instance (a function-local static).
    static ConstraintFactory& getFactory() {
        static ConstraintFactory theFactory;
        return theFactory;
    }

    /// Creates an integer constraint `lhs comp rhs`.
    IntConstraint* create(Integer* lhs, Constraint::comp_t comp, int32_t rhs);
    /// Creates an integer constraint whose right-hand side is a Constant.
    IntConstraint* create(Integer* lhs, Constraint::comp_t comp, const Constant* rhs);

    /// Creates a clock constraint `lhs comp rhs`.
    ClockConstraint* create(Clock* lhs, Constraint::comp_t comp, int32_t rhs);
    /// Creates a clock constraint whose bound is a Constant.
    ClockConstraint* create(Clock* lhs, Constraint::comp_t comp, const Constant* rhs);

    /// Creates a constraint referring to location @p loc.
    LocationConstraint* create(Location* loc);

    /// Size of the integer-constraint Flyweight.
    uint32_t totalNrIntConstraints() const {return ints.size(); }
    /// Size of the location-constraint Flyweight.
    uint32_t totalNrLocationConstraints() const {return locs.size(); }
    /// Size of the clock-constraint Flyweight.
    uint32_t totalNrClockConstraints() const {return clocks.size(); }
};

////////////////////////////////////////////////////////////////////

/**
 * Ordered conjunction (logical AND) of constraints of type T.
 *
 * T must provide isSatBy(const State*), readVars() and writeVars(), which
 * are used below. Only Conjunction<ClockConstraint> defines constrain(),
 * via the explicit specialization at the end of this header.
 *
 * Ownership: the destructor passes the stored pointers to deleteVector()
 * (common/helper.h), so the conjunction is treated as owning them.
 * NOTE(review): copy construction and copy assignment are still the
 * implicit member-wise ones. A copied Conjunction would therefore pass the
 * same pointers to deleteVector() a second time. Confirm that no caller
 * copies a Conjunction, and consider declaring them private.
 */
template<class T>
class Conjunction {
private:
    std::vector<T*> constraints;
public:
    typedef typename std::vector<T*>::const_iterator const_iterator;
    typedef typename std::vector<T*>::iterator iterator;

    Conjunction() : constraints() {}

    /// Releases the stored constraints through deleteVector().
    ~Conjunction() {deleteVector(constraints); }

    /// Appends @p constraint at the end of the conjunction.
    void add(T* constraint) {
        constraints.push_back(constraint);
    }

    /// Removes the constraint at position @p i (no bounds check). The
    /// removed pointer is dropped, not passed to deleteVector().
    void erase(int i) {
        constraints.erase(constraints.begin() + i);
    }

    /// Sets the number of stored constraints to @p i. Growing appends
    /// value-initialised (null) pointers. Shrinking drops trailing pointers
    /// without releasing them.
    /// NOTE(review): confirm that callers own the dropped pointers.
    void resize(int i) {
        constraints.resize(i);
    }

    iterator begin() {return constraints.begin(); }
    const_iterator begin() const {return constraints.begin(); }

    iterator end() {return constraints.end(); }
    const_iterator end() const {return constraints.end(); }

    /// Unchecked positional access.
    const T* operator[](uint32_t i) const {
        return constraints[i];
    }

    /// Unchecked positional access.
    T* operator[](uint32_t i) {
        return constraints[i];
    }

    /// True iff every constraint is satisfied by @p state. An empty
    /// conjunction is satisfied by every state. Evaluation stops at the
    /// first unsatisfied constraint.
    bool isSatBy(const State* state) const {
        for (const_iterator it = constraints.begin(); it != constraints.end(); ++it) {
            if (!(*it)->isSatBy(state)) {
                return false;
            }
        }
        return true;
    }

    /// Union (via varset_t::operator+=) of the variables read by all constraints.
    varset_t readVars() const {
        varset_t result;
        for (const_iterator it = constraints.begin(); it != constraints.end(); ++it) {
            result += (*it)->readVars();
        }
        return result;
    }

    /// Union (via varset_t::operator+=) of the variables written by all constraints.
    varset_t writeVars() const {
        varset_t result;
        for (const_iterator it = constraints.begin(); it != constraints.end(); ++it) {
            result += (*it)->writeVars();
        }
        return result;
    }

    /// Number of stored constraint pointers (duplicates included).
    uint32_t size() const {
        return static_cast<uint32_t>(constraints.size());
    }

    bool empty() const {
        return constraints.empty();
    }

    /// Prints the constraints separated by " && " (via printPtrVector()).
    std::ostream& display(std::ostream& o) const {
        return printPtrVector(constraints, " && ", o);
    }

    /// Set equality on the stored pointers: order and duplicates are
    /// ignored. This compares pointer identity, which matches structural
    /// equality only if equal constraints are shared objects (presumably
    /// ensured by ConstraintFactory's flyweights - verify).
    bool operator==(const Conjunction& c) const {
        const std::set<T*> self(constraints.begin(), constraints.end());
        const std::set<T*> other(c.constraints.begin(), c.constraints.end());
        return self == other;
    }

    bool operator!=(const Conjunction& c) const {
        return !(*this == c);
    }

    /// Defined only for Conjunction<ClockConstraint> (explicit
    /// specialization in this header).
    bool constrain(State* state) const;
};

template<class T>
inline std::ostream& operator<<(std::ostream& o, const Conjunction<T>& c) {
    return c.display(o);
}

/// Applies every clock constraint to @p state in order. Returns false as
/// soon as one ClockConstraint::constrain() call returns false, and true
/// otherwise (including for an empty conjunction). Constraints applied
/// before the failing one are not undone.
template<>
inline bool Conjunction<ClockConstraint>::constrain(State* state) const {
    const_iterator it = constraints.begin();
    while (it != constraints.end()) {
        const bool ok = (*it)->constrain(state);
        if (!ok) {
            return false;
        }
        ++it;
    }
    return true;
}

#endif /* SYSTEM_CONSTRAINT_H */
