#include "subsumption_trie.h"

#include "globals.h"

#include <climits>
using namespace std;

// Creates a node with `size` empty child slots (all nullptr) and no state id.
TrieNode::TrieNode(size_t size)
    : children(size, nullptr),
      id(StateID::no_state) {
}

// Does not free children; SubsumptionTrie::delete_node releases the subtree.
TrieNode::~TrieNode() = default;
// Returns the child stored in slot `index` (nullptr if the slot is empty).
// No bounds check is performed: index must be < get_size().
TrieNode *TrieNode::get_child(size_t index) {
    TrieNode *slot_value = children[index];
    return slot_value;
}

void TrieNode::insert_child(size_t index, TrieNode* child) {
    children[index] = child;
}
// Records the id of the state this node stands for.
void TrieNode::set_id(StateID state_id) {
    id = state_id;
}

// Returns the stored state id (StateID::no_state unless set_id was called).
StateID TrieNode::get_id() {
    return this->id;
}
size_t TrieNode::get_size() {
    return children.size();
}

// Creates an empty trie. The root has one child slot per value of
// variable 0 (g_variable_domain[0]).
SubsumptionTrie::SubsumptionTrie()
    : root(new TrieNode(g_variable_domain[0])) {
    // search_space is only assigned by attach_search_space(). Give it a
    // defined null value instead of leaving it indeterminate until then.
    // It is assigned here rather than in the mem-initializer list to avoid
    // depending on the member declaration order in the header.
    search_space = nullptr;
}


// Frees the whole trie: first every node below the root, then the root.
// NOTE(review): the class owns `root` through a raw pointer. Unless the header
// deletes copy construction/assignment, copying a trie would double-free.
SubsumptionTrie::~SubsumptionTrie() {
    TrieNode *top = root;
    delete_node(top);
    delete top;
}

// Adds `state` to the trie, creating any missing nodes along its path.
void SubsumptionTrie::insert(const GlobalState& state) {
    // The root's children are indexed by the value of variable 0.
    const size_t first_level = 0;
    insert(state, root, first_level);
}

void SubsumptionTrie::insert(const GlobalState& state, TrieNode* node,
                       size_t level) {
    bool at_leaf = level == g_variable_domain.size() -1;
    TrieNode *child = node->get_child(state[level]);
    if (!child) {
        size_t child_size;
        if (!at_leaf) {
            child_size = g_variable_domain[level + 1];
        } else {
            child_size = 0;
        }
        TrieNode *new_node = new TrieNode(child_size);
        child = new_node;
        node->insert_child(state[level], child);
        if (at_leaf) {
            new_node->set_id(state.get_id());
        }
    }
    if (!at_leaf) {
        insert(state, child, level + 1);
    } else {
        return;
    }
}

// Recursively frees every descendant of `node` (but not `node` itself).
// Each freed slot is reset to nullptr so `node` holds no dangling pointers.
// Passing nullptr is a no-op.
void SubsumptionTrie::delete_node(TrieNode* node) {
    if (!node)
        return;
    for (size_t i = 0; i < node->get_size(); ++i) {
        TrieNode *child = node->get_child(i);
        // Most slots are never filled (TrieNode starts with all nullptr).
        // Recursing into them would call get_size() through a null pointer.
        if (!child)
            continue;
        delete_node(child);
        delete child;
        node->insert_child(i, nullptr);
    }
}

// Returns true if the trie holds a state that subsumes `state`
// (see the recursive overload for the exact criterion).
bool SubsumptionTrie::lookup(const GlobalState& state) {
    // The recursive overload only ever switches this flag on.
    bool found = false;
    lookup(state, root, 0, found);
    return found;
}

void SubsumptionTrie::lookup(const GlobalState& state, TrieNode* node,
                             size_t level, bool& subsuming_found) {
    bool at_leaf = level == g_variable_domain.size() - 1;
    vector<TrieNode*> children;
    children.push_back(node->get_child(g_variable_domain[level] - 1));
    if (state[level] != g_variable_domain[level] - 1) {
        children.push_back(node->get_child(state[level]));
    }
    for (size_t i = 0; i < children.size(); i++) {
        if (children[i]) {
            if (at_leaf) {
                SearchNode state_node = search_space->get_node(state);
                SearchNode state_node_subsuming
                    = search_space->get_node(
                        g_state_registry->lookup_state(
                            children[i]->get_id()));
                if (state_node.get_g() >= state_node_subsuming.get_g()) {
                    subsuming_found = true;
                    return;
                }
            } else {
                lookup(state, children[i], level + 1, subsuming_found);
            }
        }
    }
}

void SubsumptionTrie::attach_search_space(SearchSpace* search_space) {
    this->search_space = search_space;
}
