import os

# Verbosity switches for print_debug1/2/3; set one to True to enable that level.
DEBUG1 = False
DEBUG2 = False
DEBUG3 = False


def print_debug1(x):
    """Print ``x`` prefixed with ``DEBUG 1: `` when ``DEBUG1`` is enabled."""
    if not DEBUG1:
        return
    print("DEBUG 1: " + str(x))
def print_debug2(x):
    """Print ``x`` prefixed with ``DEBUG 2: `` when ``DEBUG2`` is enabled."""
    if not DEBUG2:
        return
    print("DEBUG 2: " + str(x))
def print_debug3(x):
    """Print ``x`` prefixed with ``DEBUG 3: `` when ``DEBUG3`` is enabled."""
    if not DEBUG3:
        return
    print("DEBUG 3: " + str(x))



def get_dalm_idx_batch_collection(file_path):
    """Parse operator-index batches from a disjunctive action landmark file.

    Only lines strictly between ``== Disjunctive Action Landmark Graph ==``
    and ``== End of Graph ==`` are considered. Inside that section, lines
    that are exactly ``>`` or ``<``, lines starting with ``#`` and blank
    lines are skipped; every other line is a whitespace-separated list of
    integer operator indices and becomes one batch.

    :param file_path: path to the landmark (``.dalm``) file.
    :return: list of batches, each a list of ints, in file order.
    """
    dalm_idx_batch_collection = []
    active = False
    # ``with`` closes the handle deterministically instead of leaking it.
    with open(file_path) as dalm_text:
        for raw_line in dalm_text:
            # Drop only the newline so a final marker line without a trailing
            # newline is still recognised.
            line = raw_line.rstrip("\n")

            # Checked before parsing so the end marker itself is never parsed.
            if line == "== End of Graph ==":
                active = False

            if (active
                    and line.strip()  # blank lines would otherwise crash int('')
                    and line not in (">", "<")
                    and not line.startswith("#")):
                # split() tolerates repeated/trailing whitespace.
                dalm_idx_batch_collection.append([int(tok) for tok in line.split()])

            # Checked after parsing so the start marker itself is never parsed.
            if line == "== Disjunctive Action Landmark Graph ==":
                active = True

    print_debug3(dalm_idx_batch_collection)
    return dalm_idx_batch_collection

def get_domain_list(sas_file_path):
    """Parse variables and operators from a SAS task file.

    Variables: for every line equal to ``begin_variable`` at index ``idx``
    the code reads the variable name at ``idx+1`` (its first three characters
    are dropped and the rest is parsed as the integer variable id), a value
    at ``idx+2`` that must be ``-1`` (anything else is reported as an axiom)
    and the domain size at ``idx+3``. Each variable is assigned a contiguous
    block of dimensions, allocated in file order.

    Operators: for every line equal to ``begin_operator`` the code reads the
    name (``idx+1``) and the prevail-condition count (``idx+2``). It skips
    that many lines, then reads the effect count and that many effect lines of
    space-separated ints ``<cond> <var> <pre> <post>``. A non-zero first field
    is reported as a conditional effect.

    :param sas_file_path: path to the SAS task file.
    :return: tuple ``(sas_dictionary, axiom_detected,
        conditional_effect_detected, underdefined_precondtion_detected,
        op_dictionary, dimension_count)``.

        ``sas_dictionary`` maps each variable id to
        ``(first_dimension, domain_size)``.

        ``op_dictionary`` maps a running operator index (order of appearance)
        to ``(name, affected_dimensions)``. ``affected_dimensions`` is a set
        of ``(dimension, marker)`` pairs:

        - marker ``-1`` on the dimension of the precondition value;
        - marker ``1`` on the dimension of the post value;
        - marker ``('underdefined Precondition', domain_size)`` on the
          variable's first dimension, when the precondition value is ``-1``.

        ``dimension_count`` is the total number of dimensions allocated.

    NOTE(review): ``underdefined_precondtion_detected`` is never set to True
    in this function; underdefined preconditions are encoded in
    ``affected_dimensions`` instead.
    """
    sas_dictionary = dict()
    operator_count = 0
    op_dictionary = dict()
    # Running total of dimensions allocated so far; returned as the dimension count.
    previous_size = 0
    axiom_detected = False
    conditional_effect_detected = False
    # NOTE(review): never updated below, so always returned as False.
    underdefined_precondtion_detected = False
    with open(sas_file_path, "r") as sas_text:
        sas_text_list=sas_text.read().split("\n")
        for idx in range(len(sas_text_list)):
            #print(sas_text_list[idx])
            if sas_text_list[idx] == "begin_variable":

                # id (name minus its 3-char prefix) -> (first dimension, domain size at idx+3)
                sas_dictionary[int(sas_text_list[idx+1][3:])]= (previous_size ,int(sas_text_list[idx+3]))
                previous_size = previous_size + int(sas_text_list[idx + 3])
                # The line at idx+2 must be "-1"; any other value is treated as an axiom.
                if not sas_text_list[idx+2] == "-1":
                    # Dump the offending variable header for diagnosis.
                    for i in range(5):
                        print(sas_text_list[idx+i])
                    axiom_detected = True
                    print("ERROR: Axiom detected")
                    # Abandons the whole scan: later variables/operators are not parsed.
                    break
            if sas_text_list[idx] == "begin_operator":
                operator_name = sas_text_list[idx+1]
                prevail_condition_amount = int(sas_text_list[idx+2])
                # The effect count sits right after the prevail-condition lines.
                effect_amount = int(sas_text_list[idx+2+prevail_condition_amount +1])

                affected_dimensions = set()
                for i in range(effect_amount):
                    #print(sas_text_list[idx+2+prevail_condition_amount +1 +1 +i])
                    # Fields: [0] non-zero => conditional effect, [1] variable id,
                    # [2] precondition value (-1 handled below), [3] post value.
                    bi_atomar_effect=list(map(int, sas_text_list[idx+2+prevail_condition_amount +1 +1 +i].split(' ')))
                    #print(bi_atomar_effect)
                    if not bi_atomar_effect[0]==0:
                        conditional_effect_detected = True
                        print("ERROR: Conditional effect detected")
                        # Leaves only this operator's effect loop; the operator is
                        # still recorded below with the effects parsed so far.
                        break
                    if bi_atomar_effect[2] == -1:

                        # No specific precondition value: tag the variable's first
                        # dimension with its domain size so the caller can expand it.
                        domain_size = sas_dictionary[bi_atomar_effect[1]][1]
                        dim1 = sas_dictionary[bi_atomar_effect[1]][0]
                        dim2 = sas_dictionary[bi_atomar_effect[1]][0] + bi_atomar_effect[3]
                        affected_dimensions.add((dim1, ('underdefined Precondition',domain_size)))
                        affected_dimensions.add((dim2, 1))
                    else:
                        # -1 on the precondition value's dimension, +1 on the post value's.
                        dim1 = sas_dictionary[ bi_atomar_effect[1] ][0]+bi_atomar_effect[2]
                        dim2 = sas_dictionary[ bi_atomar_effect[1] ][0]+bi_atomar_effect[3]
                        affected_dimensions.add((dim1,-1))
                        affected_dimensions.add((dim2,1))
                op_dictionary[operator_count]=(operator_name, affected_dimensions)
                operator_count+=1


    return sas_dictionary, axiom_detected, conditional_effect_detected, underdefined_precondtion_detected ,op_dictionary,previous_size


def factor_out(vector, underdefined_perconditions):
    """Expand a transition vector over every value of its underdefined preconditions.

    Each entry of ``underdefined_perconditions`` is ``(domain_start,
    domain_size)``. One copy of ``vector`` is produced per combination of
    values, i.e. ``prod(domain_size)`` copies. In each copy, the dimension
    ``domain_start + value`` of every listed variable is decremented by one.
    Combinations are enumerated as mixed-radix numbers, with the first listed
    variable as the least-significant digit.

    :param vector: list of ints (copied, never modified).
    :param underdefined_perconditions: sequence of ``(domain_start, domain_size)``.
    :return: ``[vector]`` (the same object) when there is nothing to expand,
        otherwise a list of new vectors.
    """
    if not underdefined_perconditions:
        return [vector]

    mini_batch_size = 1
    for precondition in underdefined_perconditions:
        mini_batch_size *= precondition[1]

    mini_batch = []
    for combination in range(mini_batch_size):
        new_vector = vector.copy()
        remain = combination
        for precondition in underdefined_perconditions:
            domain_start, domain_size = precondition[0], precondition[1]
            # Exact integer mixed-radix digit extraction (no float division).
            remain, digit = divmod(remain, domain_size)
            new_vector[domain_start + digit] -= 1
        mini_batch.append(new_vector)

    print_debug3("original vec:" + str(vector))
    print_debug3("u-definend preCond:" + str(underdefined_perconditions))
    for vec in mini_batch:
        print_debug3(vec)
    return mini_batch



def get_vector_batch_collection(op_dictionary,dalm_idx_batch_collection, dimension_count):
    """Convert batches of operator indices into batches of transition vectors.

    For each operator index, the ``(dimension, marker)`` pairs stored in
    ``op_dictionary[idx][1]`` are applied as follows:

    - numeric markers (-1/0/1) are summed into a zero vector of length
      ``dimension_count``;
    - any other marker is read as ``(..., domain_size)`` and collected as an
      underdefined precondition, which ``factor_out`` then expands into
      several vectors.

    All vectors produced for one batch are concatenated into that batch's
    vector list.

    :param op_dictionary: operator index -> ``(name, affected_dimensions)``.
    :param dalm_idx_batch_collection: list of lists of operator indices.
    :param dimension_count: length of every produced vector.
    :return: list (one per input batch) of lists of vectors.
    """
    vector_batch_collection = []

    for dalm_idx_batch in dalm_idx_batch_collection:
        vector_batch = []
        for dalm_idx in dalm_idx_batch:
            vector = [0] * dimension_count
            underdefined_perconditions = []

            for atomar_effect in op_dictionary[dalm_idx][1]:
                dim, marker = atomar_effect[0], atomar_effect[1]
                if marker in (-1, 0, 1):
                    # Accumulate instead of overwriting (as get_next_vec_by_op_id
                    # does): an effect whose pre and post value coincide yields
                    # both (d, -1) and (d, 1), which must net to 0 regardless of
                    # set iteration order.
                    vector[dim] += marker
                else:
                    underdefined_perconditions.append((dim, marker[1]))

            vector_batch.extend(factor_out(vector, underdefined_perconditions))
        vector_batch_collection.append(vector_batch)

    return vector_batch_collection



def get_path_idx_batch(loc_result_path):
    """Read the operator indices listed between the path markers of a result file.

    Every non-blank line strictly between a line containing
    ``== begin path ==`` and one containing ``== end path ==`` contributes one
    index. The index is the line's last whitespace-separated token with its
    first three characters dropped, parsed as an int.

    :param loc_result_path: path to the local-optimum result file.
    :return: list of operator indices in file order.
    """
    path_idx_batch = []
    active = False
    # ``with`` closes the handle instead of leaking it.
    with open(loc_result_path) as loc_result_text:
        for line in loc_result_text:
            # Checked before parsing so the end marker line is never parsed.
            if "== end path ==" in line:
                active = False

            # Blank lines inside the section would otherwise crash int('').
            if active and line.strip():
                words = line.split()
                path_idx_batch.append(int(words[-1][3:]))

            # Checked after parsing so the begin marker line is never parsed.
            if "== begin path ==" in line:
                active = True

    return path_idx_batch

def get_applicable_idx_batch(loc_result_path):
    """Read the operator indices listed as applicable at the local optimum.

    Every non-blank line strictly between a line containing
    ``== applicable operators from local optimum ==`` and one containing
    ``== end of applicable operators from local optimum ==`` contributes one
    index. The index is the line's last whitespace-separated token with its
    first three characters dropped, parsed as an int.

    :param loc_result_path: path to the local-optimum result file.
    :return: list of operator indices in file order.
    """
    applicable_idx_batch = []
    active = False
    # ``with`` closes the handle instead of leaking it.
    with open(loc_result_path) as loc_result_text:
        for line in loc_result_text:
            # Checked before parsing so the end marker line is never parsed.
            if "== end of applicable operators from local optimum ==" in line:
                active = False

            # Blank lines inside the section would otherwise crash int('').
            if active and line.strip():
                words = line.split()
                applicable_idx_batch.append(int(words[-1][3:]))

            # Checked after parsing so the begin marker line is never parsed.
            if "== applicable operators from local optimum ==" in line:
                active = True

    return applicable_idx_batch





def extract_vector_batch_collection(op_idx_batch_collection, sas_path):
    """Turn batches of operator indices into batches of transition vectors.

    The SAS task at ``sas_path`` is parsed first. If it uses axioms or
    conditional effects, or reports an underdefined precondition, an error is
    printed and ``([], complain_list)`` is returned; only the first matching
    problem is reported.

    Otherwise every batch is converted by ``get_vector_batch_collection``,
    which may expand one operator into several vectors, and
    ``(vector_batch_collection, [])`` is returned.
    """
    (sas_dictionary, axiom_detected, conditional_effect_detected,
     underdefined_precondtion_detected, op_dictionary,
     dimension_count) = get_domain_list(sas_path)

    # (flag, printed error, complaint) in priority order; first hit wins.
    rejection_checks = (
        (axiom_detected, "ERROR: Task uses Axioms", "Task uses Axioms"),
        (conditional_effect_detected, "ERROR: Task uses Conditional Effects", "Task uses Conditional Effects"),
        (underdefined_precondtion_detected, "ERROR: Task has underdefined Precondition :(", "Task has underdefined Precondition"),
    )
    complain_list = []
    for detected, error_text, complaint in rejection_checks:
        if detected:
            print(error_text)
            complain_list.append(complaint)
            break
    else:
        print_debug3("looks fine")

    print_debug2("#dimensions:" + str(dimension_count))
    print_debug3("var_dict:" + str(sas_dictionary))
    print_debug3("op_dict:" + str(op_dictionary))
    print_debug3("op_idx_batch_collection:" + str(op_idx_batch_collection))
    if complain_list:
        return [], complain_list

    # Operators with underdefined preconditions are factorised into several vectors here.
    vector_batch_collection = get_vector_batch_collection(
        op_dictionary, op_idx_batch_collection, dimension_count)
    for vector_batch in vector_batch_collection:
        print_debug3(vector_batch)
    return vector_batch_collection, complain_list

def extract_dalm_vector_batch_collection(dalm_path, sas_path):
    """Vectorise the landmark batches stored in ``dalm_path`` against ``sas_path``.

    :return: ``(vector_batch_collection, complain_list)`` as produced by
        ``extract_vector_batch_collection``.
    """
    idx_batches = get_dalm_idx_batch_collection(dalm_path)
    vector_batches, complaints = extract_vector_batch_collection(idx_batches, sas_path)
    return vector_batches, complaints

def extract_path_vector_batch(loc_result_path, sas_path):
    """Vectorise the operators on the path recorded in ``loc_result_path``.

    NOTE(review): the original note said "factoring out not allowed", yet the
    shared helper does factorise operators with underdefined preconditions;
    confirm the intended behaviour.

    :return: ``(path_vector_batch, complain_list)``. ``path_vector_batch`` is
        ``[]`` when the task was rejected (``complain_list`` non-empty).
    """
    path_idx_batch = get_path_idx_batch(loc_result_path)
    path_vector_batch_collection, complain_list = extract_vector_batch_collection([path_idx_batch], sas_path)
    # A rejected task yields an empty collection; indexing [0] would raise
    # IndexError and hide the complaint list from the caller.
    path_vector_batch = path_vector_batch_collection[0] if path_vector_batch_collection else []
    return path_vector_batch, complain_list

def extract_applicable_vector_batch(loc_result_path, sas_path):
    """Vectorise the operators listed as applicable at the local optimum.

    :return: ``(applicable_vector_batch, complain_list)``.
        ``applicable_vector_batch`` is ``[]`` when the task was rejected
        (``complain_list`` non-empty).
    """
    applicable_idx_batch = get_applicable_idx_batch(loc_result_path)
    applicable_vector_batch_collection, complain_list = extract_vector_batch_collection([applicable_idx_batch], sas_path)
    # A rejected task yields an empty collection; indexing [0] would raise
    # IndexError and hide the complaint list from the caller.
    applicable_vector_batch = applicable_vector_batch_collection[0] if applicable_vector_batch_collection else []
    return applicable_vector_batch, complain_list


def get_initial_vec(sas_file_path, sas_dictionary, dimension_count):
    """Build the one-hot vector of the initial state from a SAS file.

    The first line equal to ``begin_state`` is followed by one value per
    variable, for variable ids ``0 .. len(sas_dictionary)-1``. For each
    variable, the dimension ``first_dimension + value`` is set to 1. If no
    ``begin_state`` line exists, an all-zero vector is returned.

    :param sas_dictionary: variable id -> ``(first_dimension, domain_size)``.
    :param dimension_count: length of the returned vector.
    """
    vec = [0] * dimension_count
    with open(sas_file_path, "r") as sas_text:
        lines = sas_text.read().split("\n")

    for line_no, line in enumerate(lines):
        if line != "begin_state":
            continue
        for var_id in range(len(sas_dictionary)):
            first_dim = sas_dictionary[var_id][0]
            vec[first_dim + int(lines[line_no + 1 + var_id])] = 1
        break
    return vec

def get_next_vec_by_op_id(current_vec, op_dictionary, op_idx):
    """Return the state vector reached by applying operator ``op_idx``.

    ``current_vec`` is copied, never modified. The operator's
    ``(dimension, marker)`` effects are applied in two phases:

    1. Every non-numeric marker ``(..., domain_size)`` clears the
       ``domain_size`` dimensions starting at ``dimension``.
    2. Every ``+1``/``-1`` marker is then added to its dimension.
    """
    effect_set = op_dictionary[op_idx][1]
    successor = current_vec.copy()

    pending_deltas = []
    for effect in effect_set:
        dim, marker = effect[0], effect[1]
        if marker == 1 or marker == -1:
            # Deferred so every clear happens before any delta is added.
            pending_deltas.append((dim, marker))
            continue
        for cleared in range(dim, dim + marker[1]):
            successor[cleared] = 0

    for dim, delta in pending_deltas:
        successor[dim] = successor[dim] + delta

    return successor

def get_next_vec(current_vec, step, op_dictionary, path_idx_batch):
    """Apply the ``step``-th operator of ``path_idx_batch`` to ``current_vec``."""
    return get_next_vec_by_op_id(current_vec, op_dictionary, op_idx=path_idx_batch[step])

def get_applicable_state_vecs(current_vec, op_dictionary, applicable_idx_batch):
    """Return one successor state vector per operator index in ``applicable_idx_batch``."""
    return [
        get_next_vec_by_op_id(current_vec, op_dictionary, op_idx)
        for op_idx in applicable_idx_batch
    ]

def extract_loc_vector_batch(loc_result_path, sas_path):
    """Compute state-transition vectors along a recorded path and at its end state.

    The initial state of ``sas_path`` is rolled forward through the operator
    path read from ``loc_result_path``. ``path_transitions`` holds the
    element-wise difference between consecutive states on that path.
    ``app_transitions`` holds, for each applicable operator listed in the
    result file, the difference between its successor and the final path
    state.

    :return: ``(path_transitions, app_transitions, complain_list,
        dimension_count)`` on success.

    NOTE(review): when the task is rejected (axioms, conditional effects, or
    the underdefined-precondition flag) ONLY ``complain_list`` is returned,
    not a 4-tuple; callers must handle both shapes.
    """
    (sas_dictionary, axiom_detected, conditional_effect_detected,
     underdefined_precondtion_detected, op_dictionary,
     dimension_count) = get_domain_list(sas_path)

    # (flag, printed error, complaint) in priority order; first hit wins.
    rejection_checks = (
        (axiom_detected, "ERROR: Task uses Axioms", "Task uses Axioms"),
        (conditional_effect_detected, "ERROR: Task uses Conditional Effects", "Task uses Conditional Effects"),
        (underdefined_precondtion_detected, "ERROR: Task has underdefined Precondition :(", "Task has underdefined Precondition"),
    )
    complain_list = []
    for detected, error_text, complaint in rejection_checks:
        if detected:
            print(error_text)
            complain_list.append(complaint)
            break
    else:
        print_debug3("looks fine")

    if complain_list:
        return complain_list

    path_idx_batch = get_path_idx_batch(loc_result_path)
    print_debug1("path_idx_batch: " + str(path_idx_batch))
    applicable_idx_batch = get_applicable_idx_batch(loc_result_path)
    print_debug1("applicable_idx_batch: " + str(applicable_idx_batch))

    state_vec = get_initial_vec(sas_path, sas_dictionary, dimension_count)
    vec_trace = [state_vec]
    print_debug2("cur_vec: " + str(state_vec))
    for step in range(len(path_idx_batch)):
        state_vec = get_next_vec(state_vec, step, op_dictionary, path_idx_batch)
        print_debug2("cur_vec: " + str(state_vec))
        vec_trace.append(state_vec)
    print_debug2("vec_trace" + str(vec_trace))

    local_optimum_vec = state_vec
    applicable_state_vecs = get_applicable_state_vecs(local_optimum_vec, op_dictionary, applicable_idx_batch)
    print_debug2("applicable_state_vecs" + str(applicable_state_vecs))

    # Differences between consecutive states on the path (all vectors share one length).
    path_transitions = [
        [after - before for before, after in zip(prev_vec, next_vec)]
        for prev_vec, next_vec in zip(vec_trace, vec_trace[1:])
    ]
    # Differences between each applicable successor and the final path state.
    app_transitions = [
        [succ - base for succ, base in zip(app_state_vec, local_optimum_vec)]
        for app_state_vec in applicable_state_vecs
    ]
    return path_transitions, app_transitions, complain_list, dimension_count



if __name__ == '__main__':
    import os
    import sys

    # Paths may be overridden on the command line as: <dalm> <sas> <loc_result>.
    # The defaults are the original hard-coded locations.
    _default_dir = '/home/simon/uni/masterthesis/MA/master-thesis-SDold/Code/downward'
    dalm_path = sys.argv[1] if len(sys.argv) > 1 else os.path.join(_default_dir, 'output.dalm')
    sas_path = sys.argv[2] if len(sys.argv) > 2 else os.path.join(_default_dir, 'output.sas')
    loc_result_path = sys.argv[3] if len(sys.argv) > 3 else os.path.join(_default_dir, 'los_result.txt')

    dalm_vector_batch_collection, complaint_list = extract_dalm_vector_batch_collection(dalm_path, sas_path)
    print(dalm_vector_batch_collection)

    print("xxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxxx")

    loc_result = extract_loc_vector_batch(loc_result_path, sas_path)
    if isinstance(loc_result, list):
        # On rejection extract_loc_vector_batch returns only the complaint list.
        print(loc_result)
    else:
        # The success path returns four values; unpacking three raised ValueError.
        path_vector_batch, applicable_vector_batch, complain_list, dimension_count = loc_result
        print(path_vector_batch)
        print(applicable_vector_batch)
    print("VECTORIZER.py WAS CALLED AS MAIN ???")