import json

import pandas
import pandas as pd
import seaborn as sns
from matplotlib import pyplot as plt, pyplot
import tikzplotlib


def list_to_string(li):
    """Concatenate an iterable of strings into one string.

    Uses ``str.join`` so the cost is linear in the total length, instead of
    the repeated ``+=`` concatenation that can degrade to quadratic time.

    Args:
        li: iterable of ``str`` items. A plain ``str`` is also accepted and
            yields an equal string, since iterating it produces its characters.

    Returns:
        The concatenation of all items; ``""`` for an empty input.

    Raises:
        TypeError: if any item is not a ``str``.
    """
    return "".join(li)


def add_identity(axes, *line_args, **line_kwargs):
    """Overlay a y = x reference line that follows the visible axis limits.

    The segment runs from the larger of the two lower limits to the smaller
    of the two upper limits. It is recomputed whenever the x or y limits
    change, via the ``xlim_changed`` / ``ylim_changed`` callbacks.

    Args:
        axes: matplotlib Axes to draw on.
        *line_args: forwarded positionally to ``axes.plot``.
        **line_kwargs: forwarded to ``axes.plot`` (e.g. ``color``, ``ls``).

    Returns:
        The same ``axes`` object.
    """
    # The tuple target insists that plot() returns exactly one Line2D.
    (diagonal,) = axes.plot([], [], *line_args, **line_kwargs)

    def _sync(ax):
        x_lo, x_hi = ax.get_xlim()
        y_lo, y_hi = ax.get_ylim()
        lo, hi = max(x_lo, y_lo), min(x_hi, y_hi)
        diagonal.set_data([lo, hi], [lo, hi])

    # Size the line for the current limits, then keep it in sync on zoom/pan.
    _sync(axes)
    for event_name in ('xlim_changed', 'ylim_changed'):
        axes.callbacks.connect(event_name, _sync)
    return axes


def reindex_by_hand(path):
    """Load a results JSON file and split its records by algorithm family.

    Each top-level value is treated as a record dict and modified in place:
    - ``build_options``, ``component_options``, ``driver_options`` and ``id``
      are flattened to single strings with ``list_to_string``;
    - ``initial_h_values`` is set to ``None``.

    Records whose ``algorithm`` contains ``'helm'`` go into the ``"helm"``
    bucket. Every other record goes into the ``"rin"`` bucket.

    Args:
        path: path to a JSON file whose top level is an object mapping keys to
            record dicts (assumed from the indexing below — confirm).

    Returns:
        ``{"helm": {key: record, ...}, "rin": {key: record, ...}}``.
    """
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform's locale encoding.
    with open(path, encoding="utf-8") as f:
        json_full = json.load(f)
    json_mod = {"helm": {}, "rin": {}}
    for key, element in json_full.items():
        for field in ("build_options", "component_options", "driver_options", "id"):
            element[field] = list_to_string(element[field])
        element['initial_h_values'] = None
        # Anything not recognised as 'helm' is assumed to be a 'rin' run.
        bucket = "helm" if 'helm' in element["algorithm"] else "rin"
        json_mod[bucket][key] = element
    return json_mod


def get_label(min_wins_text, name, wins):
    """Build an axis label such as ``"helm (lower for 3 tasks)"``.

    Args:
        min_wins_text: wording of the win criterion, e.g. ``"lower"``.
        name: configuration name shown first.
        wins: number of tasks won; converted with ``str``.

    Returns:
        ``"<name> (<min_wins_text> for <wins> tasks)"``.
    """
    pieces = [name, " (", min_wins_text, " for ", str(wins), " tasks)"]
    return "".join(pieces)


def do_scatter_plot(df, x_name, y_name, attribute, min_wins_text="lower"):
    """Scatter one attribute of two configurations against each other; save as PDF.

    The values ``df.loc[x_name, attribute]`` and ``df.loc[y_name, attribute]``
    are paired element-wise with ``zip``, so the shorter one sets the number of
    points. The plot has:
    - a grey identity line;
    - axis labels reporting how many tasks each configuration "wins";
    - a title reporting how many tasks have a value for both.

    The figure is written to ``scatter-plot-<attribute>.pdf``.

    Args:
        df: DataFrame indexed by configuration name. Its ``attribute`` cells
            are assumed to hold per-task sequences (from the ``zip``) — confirm.
        x_name: row label plotted on the x axis.
        y_name: row label plotted on the y axis.
        attribute: column to compare; also used in the title and file name.
        min_wins_text: ``"lower"`` if the smaller value wins a task,
            ``"greater"`` if the larger value wins.

    Raises:
        ValueError: if ``min_wins_text`` is neither ``"lower"`` nor ``"greater"``.
    """
    # Explicit check instead of `assert`, which is stripped under `python -O`.
    if min_wins_text not in ("lower", "greater"):
        raise ValueError(
            "min_wins_text must be 'lower' or 'greater', got {!r}".format(min_wins_text))

    h = df.loc[x_name, attribute]
    r = df.loc[y_name, attribute]
    df_scatter = pd.DataFrame(zip(h, r), columns=[x_name, y_name])
    xs = df_scatter[x_name]
    ys = df_scatter[y_name]
    # Comparisons against NaN are False, so incomplete tasks count for nobody.
    if min_wins_text == "lower":
        x_wins = xs.lt(ys).sum()
        y_wins = ys.lt(xs).sum()
    else:
        x_wins = xs.gt(ys).sum()
        y_wins = ys.gt(xs).sum()
    # Only tasks with a value for both configurations count towards the total.
    total = int(df_scatter.notna().all(axis=1).sum())

    ax = df_scatter.plot.scatter(x=x_name, y=y_name, logy=False, c="Black", marker="x")
    add_identity(ax, color='gray', ls='-')
    # Work on the axes/figure pandas returned rather than pyplot's implicit
    # "current" figure, so labels and the saved file always belong to this plot.
    ax.set_ylabel(get_label(min_wins_text, y_name, y_wins))
    ax.set_xlabel(get_label(min_wins_text, x_name, x_wins))
    figure = ax.figure
    figure.suptitle(attribute.replace("_", " ") + " (" + str(total) + " tasks)")
    figure.savefig('scatter-plot-' + attribute + '.pdf', format='pdf')


def rename_columns(df, id_to_origin_target, introduce_new_lvl=False):
    """Rename DataFrame columns whose names contain a given id.

    Each entry ``key -> (origin, target)`` in ``id_to_origin_target`` applies
    to every column whose name contains ``key`` as a substring. Such a column
    is renamed by replacing ``origin + "-"`` with ``target``. If a column
    matches several keys, the last matching key determines its new name.

    When ``introduce_new_lvl`` is true:
    - each matched column maps to the tuple ``(origin, new_name)`` instead;
    - the result is transposed (former columns become rows) and reindexed by
      a MultiIndex built from those tuples;
    - rows for columns that matched no key are therefore dropped.

    NOTE(review): in that mode, if no column matches,
    ``MultiIndex.from_tuples([])`` is expected to raise — confirm.

    Args:
        df: DataFrame whose column names are strings.
        id_to_origin_target: mapping ``key -> (origin, target)``.
        introduce_new_lvl: build a two-level index instead of flat renames.

    Returns:
        The renamed DataFrame (transposed and reindexed when
        ``introduce_new_lvl`` is true). ``df`` itself is not modified.
    """
    column_names = df.columns.to_list()
    mapping = dict()
    for key in id_to_origin_target:
        prefix = id_to_origin_target[key][0]
        replacement = id_to_origin_target[key][1]
        for name in (c for c in column_names if key in c):
            new_name = name.replace(prefix + "-", replacement)
            mapping[name] = (prefix, new_name) if introduce_new_lvl else new_name

    result = df.rename(columns=mapping)
    if not introduce_new_lvl:
        return result
    multi_index = pd.MultiIndex.from_tuples(list(mapping.values()))
    return result.T.reindex(multi_index)


def do_regplot(x_name, y_name, data, outfile_name=None, x_log=True, y_log=True, show=True, save=True, tex=False):
    """Scatter ``y_name`` against ``x_name`` with seaborn (no regression fit).

    Args:
        x_name: column of ``data`` for the x axis.
        y_name: column of ``data`` for the y axis.
        data: DataFrame passed to ``sns.regplot``.
        outfile_name: output path without extension. If falsy, defaults to
            ``plot-<x_name>-<y_name>``.
        x_log: use a logarithmic x axis.
        y_log: use a logarithmic y axis.
        show: call ``plt.show()`` at the end.
        save: write the figure to disk.
        tex: save TikZ ``.tex`` via tikzplotlib instead of a ``.pdf``.

    NOTE(review): no ``ax`` is passed to ``sns.regplot``, so it draws on the
    current Axes. Repeated calls without an intervening show/close may
    overlay plots.
    """
    # fit_reg=False -> plain scatter. (sns.lmplot with hue="algorithm" was an alternative.)
    ax = sns.regplot(x=x_name, y=y_name, data=data, fit_reg=False)
    if x_log:
        ax.set_xscale('log')
    if y_log:
        ax.set_yscale('log')
    if save:
        stem = outfile_name if outfile_name else 'plot-' + x_name + "-" + y_name
        if tex:
            tikzplotlib.save(stem + '.tex')
        else:
            plt.savefig(stem + '.pdf', format='pdf')
    if show:
        plt.show()


def do_regplot_from_dict(key_to_argument, data):
    """Call ``do_regplot`` with its arguments taken from a dict.

    ``key_to_argument`` must contain ``"x_name"`` and ``"y_name"``. These
    optional keys override the matching ``do_regplot`` parameters:
    - ``"outfile_name"`` (default ``None``);
    - ``"x_log"``, ``"y_log"``, ``"show"``, ``"save"`` (default ``True``);
    - ``"tex"`` (default ``False``).

    Any other keys are ignored.

    Args:
        key_to_argument: dict of plotting options, as described above.
        data: DataFrame forwarded to ``do_regplot``.

    Raises:
        KeyError: if ``"x_name"`` or ``"y_name"`` is missing.
    """
    # dict.get does one lookup per option instead of `in` + `[]`.
    options = key_to_argument
    do_regplot(
        options["x_name"],
        options["y_name"],
        data,
        outfile_name=options.get("outfile_name", None),
        x_log=options.get("x_log", True),
        y_log=options.get("y_log", True),
        show=options.get("show", True),
        save=options.get("save", True),
        tex=options.get("tex", False),
    )
