from docplex.mp.model import Model

# Verbosity switches read by print_debug1 / print_debug2 / print_debug3.
# All three are disabled.
DEBUG1 = False
DEBUG2 = False
DEBUG3 = False

# Lower/upper bound (-AUGUSTO_FACTOR .. +AUGUSTO_FACTOR) of the continuous
# weight variables built in find_refined_vector2.
AUGUSTO_FACTOR = 10 ** 6
# Margin of the "<" indicator constraints (c . v <= -EPSILON) built by
# create_dalm_constraint and create_p_a_constraint.
EPSILON = 1

def print_debug1(x):
    """Print ``x`` prefixed with ``"DEBUG 1: "``, but only when DEBUG1 is set."""
    if not DEBUG1:
        return
    print(f"DEBUG 1: {x!s}")
def print_debug2(x):
    """Print ``x`` prefixed with ``"DEBUG 2: "``, but only when DEBUG2 is set."""
    if not DEBUG2:
        return
    print(f"DEBUG 2: {x!s}")
def print_debug3(x):
    """Print ``x`` prefixed with ``"DEBUG 3: "``, but only when DEBUG3 is set."""
    if not DEBUG3:
        return
    print(f"DEBUG 3: {x!s}")

def normalize(vec):
    """Scale ``vec`` in place so that its largest absolute entry becomes 1.

    An empty vector, or one whose entries are all zero, is left untouched.
    Returns None (the list is modified in place).
    """
    # Only entries with abs(e) > 0 can raise the maximum; NaN entries compare
    # False and are therefore skipped, exactly like a running-max scan from 0.
    largest = max((abs(e) for e in vec if abs(e) > 0), default=0)
    if largest == 0:
        return
    factor = 1 / largest
    for i, e in enumerate(vec):
        vec[i] = e * factor




def find_vector(batches, eps=0.0001):
    """Search for weights ``c`` that satisfy every batch of landmark vectors.

    ``batches`` is a list of batches; each batch is a list of equally long
    numeric vectors (disjunctive action landmarks), e.g.
    ``[[g1, g2], [g3, g4, g5], [g6], [g1, g3, g7]]``.  Reading ``gi`` as the
    operator ``oi``, this says every plan contains
    ``(o1 OR o2) AND (o3 OR o4 OR o5) AND (o6) AND (o1 OR o3 OR o7)``.

    The MIP has continuous weights ``c[i]`` in ``[-1, 1]`` and one binary per
    distinct vector ``v``. A binary set to 1 forces ``c . v <= -eps``, and
    every batch must have at least one of its binaries set to 1.

    :param batches: list of batches of numeric vectors.
    :param eps: margin of the ``c . v <= -eps`` constraints.
    :returns: ``(True, c)`` with ``c`` rescaled by ``normalize`` when the MIP
        is feasible, otherwise ``(False, [])``. Vectors of different lengths
        print an error and yield ``(False, [])``. An empty ``batches`` imposes
        no constraint and yields ``(True, [])``.
    """
    # An empty batch is an empty disjunction: no weight vector can satisfy it.
    if not all(batches):
        return False, []
    vectors = [vec for batch in batches for vec in batch]
    if not vectors:
        return True, []
    dim = len(vectors[0])
    # Validate every vector, not just those of the first batch.
    if any(len(vec) != dim for vec in vectors):
        print("ERROR: Vectors are of different dimensions")
        return False, []

    mip = Model(name='mip')

    idx_continuous = list(range(dim))
    c = mip.continuous_var_dict(lb=-1, ub=1, keys=idx_continuous)

    # One binary per *distinct* vector. A vector shared by several batches
    # must map to a single variable. dict.fromkeys keeps first-seen order.
    idx_binary = list(dict.fromkeys(tuple(vec) for vec in vectors))
    print_debug2(idx_binary)
    b = mip.binary_var_dict(idx_binary)
    for key in b:
        print_debug1(str(key) + "->" + str(b[key]))

    for i, batch in enumerate(batches):
        print_debug2("batch" + str(i) + ": " + str(batch))
        mip.add_constraint(mip.sum(b[tuple(vec)] for vec in batch) >= 1)
        for vec in batch:
            mip.add_if_then(b[tuple(vec)] == 1,
                            mip.sum(c[idx] * vec[idx] for idx in idx_continuous) <= -eps)

    if mip.solve() is None:
        return False, []

    # Read the weights by key rather than relying on variable creation order.
    weights = [c[idx].solution_value for idx in idx_continuous]
    for var in list(c.values()) + list(b.values()):
        print_debug1(str(var) + " = " + str(var.solution_value))
    normalize(weights)
    return True, weights


def find_refined_vector(batches, path_vector_batch, applicable_vector_batch, dim=None, eps=0.0001):
    """Like ``find_vector``, additionally constrained by path/applicable vectors.

    Builds a MIP with weights ``c[i]`` in ``[-1, 1]`` and one binary per
    distinct key (see ``create_indicator_var``):

    * every batch in ``batches`` needs at least one vector ``v`` with
      ``c . v <= -eps``;
    * at least one path vector ``p`` must satisfy ``c . p >= 0``, or at least
      one applicable vector ``a`` must satisfy ``c . a <= -eps``.

    :param batches: list of batches of numeric vectors (lists).
    :param path_vector_batch: list of path vectors.
    :param applicable_vector_batch: list of applicable vectors.
    :param dim: number of weights. If None, it is taken from the length of the
        first path vector (or the first applicable vector if there are no
        path vectors).
    :param eps: margin of the ``<= -eps`` constraints.
    :returns: ``(True, c)`` with ``c`` rescaled by ``normalize`` when feasible,
        otherwise ``(False, [])``.
    """
    # An empty batch, or no path/applicable vector at all, is an empty
    # disjunction that no weight vector can satisfy.
    if not all(batches) or not (path_vector_batch or applicable_vector_batch):
        return False, []
    if dim is None:
        dim = len((path_vector_batch or applicable_vector_batch)[0])

    mip = Model(name='mip')

    idx_continuous = list(range(dim))
    c = mip.continuous_var_dict(lb=-1, ub=1, keys=idx_continuous)

    def dot(vec):
        # Linear expression c . vec over the first `dim` coordinates.
        return mip.sum(c[idx] * vec[idx] for idx in idx_continuous)

    dalm_keys = [create_indicator_var(vec, "<") for batch in batches for vec in batch]
    print_debug2("dalm vectors:" + str(dalm_keys))
    path_keys = [create_indicator_var(vec, ">=") for vec in path_vector_batch]
    print_debug2("path vectors: " + str(path_keys))
    applicable_keys = [create_indicator_var(vec, "<") for vec in applicable_vector_batch]
    print_debug2("applicable vectors: " + str(applicable_keys))
    # The same key can occur several times (a vector shared by batches, or a
    # vector that is both a landmark and an applicable vector). Both uses
    # impose the same "c . v <= -eps" constraint, so they share one binary.
    b = mip.binary_var_dict(list(dict.fromkeys(dalm_keys + path_keys + applicable_keys)))
    for key in b:
        print_debug1(str(key) + "->" + str(b[key]))

    for i, batch in enumerate(batches):
        print_debug2("dalm batch" + str(i) + ": " + str(batch))
        keys = [create_indicator_var(vec, "<") for vec in batch]
        mip.add_constraint(mip.sum(b[key] for key in keys) >= 1)
        for key, vec in zip(keys, batch):
            mip.add_if_then(b[key] == 1, dot(vec) <= -eps)

    # Disjunction over all path and applicable vectors.
    mip.add_constraint(mip.sum(b[key] for key in path_keys + applicable_keys) >= 1)
    for key, vec in zip(path_keys, path_vector_batch):
        mip.add_if_then(b[key] == 1, dot(vec) >= 0)
    for key, vec in zip(applicable_keys, applicable_vector_batch):
        mip.add_if_then(b[key] == 1, dot(vec) <= -eps)

    if mip.solve() is None:
        return False, []

    weights = [c[idx].solution_value for idx in idx_continuous]
    for var in list(c.values()) + list(b.values()):
        print_debug1(str(var) + " = " + str(var.solution_value))
    normalize(weights)
    return True, weights


def create_indicator_var(vec, relation):
    """Return the dictionary key of the indicator (binary) variable for ``vec``.

    The key is ``relation`` followed by the entries of ``vec``, e.g.
    ``create_indicator_var([1, 0], "<") == ("<", 1, 0)``. Any iterable of
    coefficients is accepted (lists, tuples, ...).
    """
    return (relation, *vec)


def create_binary_var_set(p_a_list, dalm_batch_collection):
    """Collect the keys of every indicator (binary) variable the MIP needs.

    Keys come from ``create_indicator_var``:

    * each landmark vector in each batch of ``dalm_batch_collection`` -> ``"<"``
    * each vector in ``p_a_batch[1]`` (applicable transitions) -> ``"<"``
    * each vector in ``p_a_batch[0]`` (path transitions) -> ``">="``

    Landmark keys are also echoed through ``print_debug2``. The result is a
    set, so identical keys collapse into one entry.
    """
    keys = set()
    for landmark_batch in dalm_batch_collection:
        for landmark in landmark_batch:
            landmark_key = create_indicator_var(landmark, "<")
            print_debug2(landmark_key)
            keys.add(landmark_key)
    for p_a_batch in p_a_list:
        keys.update(create_indicator_var(transition, "<") for transition in p_a_batch[1])
        keys.update(create_indicator_var(transition, ">=") for transition in p_a_batch[0])
    return keys


def create_dalm_constraint(dalm_batch, mip, b, c, dim):
    """Add the constraints for one batch of disjunctive action landmarks.

    At least one binary ``b[("<", *v)]`` of the batch must be 1, and each such
    binary being 1 forces ``sum(c[i] * v[i] for i < dim) <= -EPSILON``.
    """
    landmark_keys = [tuple(["<"] + landmark) for landmark in dalm_batch]
    mip.add_constraint(mip.sum(b[key] for key in landmark_keys) >= 1)
    for key, landmark in zip(landmark_keys, dalm_batch):
        mip.add_if_then(b[key] == 1,
                        mip.sum(c[idx] * landmark[idx] for idx in range(dim)) <= -EPSILON)

def create_dalm_batch_collection_constraints(dalm_batch_collection, mip, b, c, dim):
    """Apply ``create_dalm_constraint`` to every landmark batch, in order."""
    for landmark_batch in dalm_batch_collection:
        create_dalm_constraint(landmark_batch, mip, b, c, dim)


def create_p_a_constraint(p_a_batch, mip, b, c, dim):
    """Add the constraints for one (path vectors, applicable vectors) pair.

    ``p_a_batch[0]`` holds path vectors and ``p_a_batch[1]`` applicable
    vectors. At least one of their binaries must be 1. A path binary
    ``b[(">=", *p)]`` forces ``c . p >= 0``; an applicable binary
    ``b[("<", *a)]`` forces ``c . a <= -EPSILON`` (dot products over the
    first ``dim`` coordinates).
    """
    print_debug3("p_a_batch: " + str(p_a_batch))
    paths = p_a_batch[0]
    applicables = p_a_batch[1]
    path_keys = [tuple([">="] + path) for path in paths]
    applicable_keys = [tuple(["<"] + applicable) for applicable in applicables]
    mip.add_constraint(mip.sum(b[key] for key in path_keys)
                       + mip.sum(b[key] for key in applicable_keys)
                       >= 1)
    for key, path in zip(path_keys, paths):
        mip.add_if_then(b[key] == 1,
                        mip.sum(c[idx] * path[idx] for idx in range(dim)) >= 0)
    for key, applicable in zip(applicable_keys, applicables):
        mip.add_if_then(b[key] == 1,
                        mip.sum(c[idx] * applicable[idx] for idx in range(dim)) <= -EPSILON)



def create_p_a_list_constraints(p_a_list, mip, b, c, dim):
    """Apply ``create_p_a_constraint`` to every (path, applicable) pair, in order."""
    for path_applicable_pair in p_a_list:
        create_p_a_constraint(path_applicable_pair, mip, b, c, dim)


def find_refined_vector2(dalm_batch_collection, p_a_list, dim):
    """Search for weights satisfying the landmark and path/applicable constraints.

    Continuous weights ``c[i]`` range over ``[-AUGUSTO_FACTOR, AUGUSTO_FACTOR]``.
    Binary keys come from ``create_binary_var_set``. The constraints are added
    by ``create_dalm_batch_collection_constraints`` and
    ``create_p_a_list_constraints``.

    :param dalm_batch_collection: list of batches of landmark vectors.
    :param p_a_list: list of ``(path vectors, applicable vectors)`` pairs.
    :param dim: number of weights.
    :returns: ``(found, weights, continuous_var_count, binary_var_count)``.
        ``weights`` is ``[]`` when infeasible. Otherwise it holds the raw
        solver values; ``normalize`` is not applied here.
    """
    mip = Model(name='mip')

    idx_continuous = list(range(dim))
    c = mip.continuous_var_dict(lb=-AUGUSTO_FACTOR, ub=AUGUSTO_FACTOR, keys=idx_continuous)

    # create_binary_var_set returns a set of tuples starting with a str, whose
    # iteration order changes between processes (str hash randomisation).
    # Sorting makes the model's variable order identical on every run.
    b = mip.binary_var_dict(sorted(create_binary_var_set(p_a_list, dalm_batch_collection)))
    for key in b:
        print_debug1(str(key) + "->" + str(b[key]))

    create_dalm_batch_collection_constraints(dalm_batch_collection, mip, b, c, dim)
    create_p_a_list_constraints(p_a_list, mip, b, c, dim)
    print("start MIP solving...")
    sol = mip.solve()
    print("MIP finished")
    continuous_var_count, binary_var_count = dim, mip.number_of_variables - dim

    if sol is None:
        return False, [], continuous_var_count, binary_var_count

    # Read the weights by key rather than relying on variable creation order.
    weights = [c[idx].solution_value for idx in idx_continuous]
    for var in b.values():
        print_debug1(str(var) + " = " + str(var.solution_value))
    print_debug1("vec" + str(weights))
    return True, weights, continuous_var_count, binary_var_count


if __name__ == '__main__':
    def _report(solution_exists, solution_vector):
        # Print the outcome of one search in the shared wording.
        if not solution_exists:
            print("no solution -> dimension of potential heuristic is larger than expected")
        else:
            print("Solution found, these weights could maybe produce a potential heuristic of the expected dimension")
            print(solution_vector)

    print("miconics Example:...")

    # Transition vectors between the three locations A, B and C.
    from_A_to_B = [-1, 1, 0]
    from_A_to_C = [-1, 0, 1]
    from_B_to_A = [1, -1, 0]
    from_B_to_C = [0, -1, 1]
    from_C_to_A = [1, 0, -1]
    from_C_to_B = [0, 1, -1]

    To_A = [from_B_to_A, from_C_to_A]
    To_B = [from_A_to_B, from_C_to_B]
    From_A = [from_A_to_B, from_A_to_C]
    From_B = [from_B_to_A, from_B_to_C]
    From_C = [from_C_to_A, from_C_to_B]

    # Alternative noted as MIP solvable: [To_A, To_B, From_A, From_C]
    batches = [To_A, To_B, From_A, From_B]  # MIP unsolvable

    path_vector_batch = [[1, 1, 1], [2, 1, 2]]
    applicable_vector_batch = [[1, 2, 3], [4, 9, 16]]
    dim = len(from_A_to_B)

    _report(*find_refined_vector(batches, path_vector_batch, applicable_vector_batch, dim))
    _report(*find_vector(batches))
