# temp
from pprint import pprint
import itertools
import subprocess
import statistics

from matplotlib import cm, pyplot
from lab.reports import Report
from downward.reports.scatter_matplotlib import ScatterMatplotlib
from downward.reports.scatter_pgfplots import ScatterPgfplots

def ID(val):
    """Identity function: hand back *val* unchanged.

    Serves as the default no-op hook (formatter / categoriser / post-processor)
    for the report classes in this module.
    """
    unchanged = val
    return unchanged

def fmt_domain(domain):
    """Return the LaTeX macro that names a domain variant.

    Known variants are 'default', 'no-derived', 'ordered' and
    'no-derived-ordered'. For any other value a message is printed and the
    process terminates with exit status 1.
    """
    macros = {
        'default': "\\ddefault{}",
        'no-derived': "\\dnd{}",
        "ordered": "\\dordered{}",
        "no-derived-ordered": "\\dndo{}",
    }
    if domain in macros:
        return macros[domain]
    print("Unknown domain", domain)
    # The `exit` builtin is injected by the site module for interactive use
    # and may be absent (python -S, embedded interpreters); raising
    # SystemExit directly has the same effect without that dependency.
    raise SystemExit(1)

def fmt_alg(alg):
    """Strip the leading prefix up to (and including) the first '-'.

    Names without a '-' are returned unchanged, e.g. 'lama-first' -> 'first',
    'a-b-c' -> 'b-c', 'plain' -> 'plain'.
    """
    _prefix, separator, remainder = alg.partition('-')
    if not separator:
        return alg
    return remainder



class TxtReport(Report):
    """Base class for reports that write one text file.

    Subclasses implement get_txt(); write() stores its result in
    self.outfile. The first configured attribute is kept as self.attribute.
    """

    def __init__(self, **kwargs):
        Report.__init__(self, **kwargs)
        self.attribute = self.attributes[0]

    def get_txt(self):
        """Return the report text; must be overridden by subclasses."""
        # NotImplementedError subclasses Exception, so existing
        # `except Exception` handlers keep working.
        raise NotImplementedError('Implement me')

    def write(self):
        """Write the text produced by get_txt() to self.outfile."""
        # Only writing is needed ("w" instead of "w+"); an explicit encoding
        # makes the output independent of the platform locale.
        with open(self.outfile, "w", encoding="utf-8") as f:
            f.write(self.get_txt())


class MEAN_TIME_REPORT(TxtReport):
    """LaTeX table of the mean ``total_time`` per algorithm and domain.

    Rows are algorithms (formatted with fmt_alg), columns are domains
    (formatted with fmt_domain). Runs without a ``total_time`` property are
    ignored; cells without any data are printed as 0.00.
    """

    def __init__(self, **kwargs):
        TxtReport.__init__(self, **kwargs)

    def get_data(self):
        """Return ``{algorithm: {domain: mean total_time}}``.

        (algorithm, domain) pairs without any recorded ``total_time`` are
        left out, because statistics.mean raises on an empty list.
        """
        samples = {}
        for run in self.props.values():
            per_domain = samples.setdefault(run["algorithm"], {})
            times = per_domain.setdefault(run["domain"], [])
            if "total_time" in run:
                times.append(run["total_time"])
        return {
            alg: {
                domain: statistics.mean(times)
                for domain, times in per_domain.items()
                if times
            }
            for alg, per_domain in samples.items()
        }

    def get_txt(self):
        """Render get_data() as a LaTeX tabular."""
        data = self.get_data()

        # Sorted: iterating a set of strings yields a different order in
        # every interpreter run (hash randomization).
        domains = sorted({dom for per_domain in data.values() for dom in per_domain})
        header = [fmt_domain(dom) for dom in domains]

        lines = [
            "\\begin{tabular}{l|" + " ".join("c" * len(domains)) + "}",
            " & " + " & ".join(header) + " \\\\ ",
            "\\hline ",
        ]
        for alg, per_domain in data.items():
            cells = [fmt_alg(str(alg))]
            cells.extend(f"{per_domain.get(dom, 0):.2f}" for dom in domains)
            lines.append(" & ".join(cells) + "\\\\ ")
        lines.append("\\end{tabular}")
        return "\n".join(lines)



class AggregateReport(TxtReport):
    """Aggregate one attribute per category over comparable runs.

    Runs are grouped by (algorithm, problem). Only groups that contain a run
    for every category value (``category_attr``) and in which all those runs
    define the attribute are used. Per category, the collected values are
    reduced with ``acc_func``. The resulting ``{category: aggregate}`` dict is
    passed to ``med_func``, and ``str()`` of its result is the report text.
    """

    def __init__(self, category_attr="domain", acc_func=statistics.mean, med_func=ID, **kwargs):
        TxtReport.__init__(self, **kwargs)
        self.category_attr = category_attr
        self.acc_func = acc_func
        self.med_func = med_func

    def _group_runs(self):
        """Return ``({(algorithm, problem): {category: run}}, categories)``."""
        groups = {}
        categories = set()
        for run in self.props.values():
            category = run[self.category_attr]
            groups.setdefault((run["algorithm"], run["problem"]), {})[category] = run
            categories.add(category)
        # Sorted so the category order (and thus the key order of the dict
        # handed to med_func) is the same in every run; the iteration order
        # of a set of strings changes with hash randomization.
        return (groups, sorted(categories, key=str))

    def _extract_attrs(self, groups, suits):
        """Return one list of attribute values (in ``suits`` order) per
        group that has a run for every suite and no missing value."""
        attr = self.attribute
        rows = []
        for runs in groups.values():
            if not all(suite in runs for suite in suits):
                continue
            values = [runs[suite].get(attr, None) for suite in suits]
            if None not in values:
                rows.append(values)
        return rows

    def get_txt(self):
        """Aggregate per category and return ``str(med_func(aggregates))``."""
        groups, suites = self._group_runs()
        rows = self._extract_attrs(groups, suites)

        # Transpose the rows into one value list per suite.
        per_suite = {}
        for values in rows:
            for suite, value in zip(suites, values):
                per_suite.setdefault(suite, []).append(value)

        agg = {suite: self.acc_func(values) for suite, values in per_suite.items()}
        print(agg)
        return str(self.med_func(agg))

class LatexTable(TxtReport):
    """LaTeX table that counts runs per attribute bucket and category.

    Each run's attribute value is mapped to a row label with ``categorize``.
    Its ``category_attr`` value selects the column. Cells hold the number of
    runs that fall into that (row, column) pair.
    """

    def __init__(self, category_attr="domain", categorize=ID, **kwargs):
        TxtReport.__init__(self, **kwargs)

        assert len(self.attributes) == 1
        self.category_attr = category_attr
        self.acc = categorize

    def get_data(self):
        """Return ``{row_label: {category: count}}``."""
        values_by_category = {}
        for run in self.props.values():
            values_by_category.setdefault(run[self.category_attr], []).append(
                run[self.attribute])

        counts = {}
        for category, values in values_by_category.items():
            for value in values:
                row = counts.setdefault(self.acc(value), {})
                row[category] = row.get(category, 0) + 1
        return counts

    def get_txt(self):
        """Render get_data() as a LaTeX tabular (missing cells print 0)."""
        data = self.get_data()

        # Sorted: iterating a set of strings yields a different order in
        # every interpreter run (hash randomization).
        columns = sorted({col for cells in data.values() for col in cells}, key=str)

        if self.category_attr == "domain":
            header = [fmt_domain(col) for col in columns]
        else:
            # Non-domain categories may be numbers; str.join needs strings.
            header = [str(col) for col in columns]

        lines = [
            "\\begin{tabular}{l|" + " ".join("c" * len(columns)) + "}",
            " & " + " & ".join(header) + " \\\\ ",
            "\\hline ",
        ]
        for row, cells in data.items():
            values = [str(row)] + [f"{cells.get(col, 0)}" for col in columns]
            lines.append(" & ".join(values) + "\\\\ ")
        lines.append("\\end{tabular}")
        return "\n".join(lines)


class BarChartReport(Report):
    """Bar chart of one attribute, accumulated per domain.

    For every domain a "Total" bar is drawn, and the accumulated value
    (legend label self.label) is drawn over it. The chart is saved to
    self.outfile.

    ``acc_func`` receives the list of attribute values of one domain and
    should return a ``(value, total)`` pair. A plain number (e.g. from the
    default ``sum``) is paired with the number of runs of that domain.
    ``fmt_alg``, if given, formats the x-axis tick labels.
    """

    def __init__(self,  acc_func=sum, fmt_alg=None, **kwargs):
        Report.__init__(self, **kwargs)

        self.attribute = self._get_attr()
        self.acc = acc_func

        self.title = "Amount Solved"
        self.label = "Solved"
        self.fmt_alg = fmt_alg

    def _categorize(self):
        """Group the attribute values of all runs by domain."""
        group = {}
        for run in self.props.values():
            group.setdefault(run['domain'], []).append(run[self.attribute])
        return group

    def _accumulate(self, cats):
        """Return ``[(domain, (value, total)), ...]`` using self.acc."""
        import numbers

        result = []
        for key, values in cats.items():
            accumulated = self.acc(values)
            if isinstance(accumulated, numbers.Number):
                # _create_plot unpacks (value, total); a bare number such as
                # the default sum(...) would make that unpacking fail.
                accumulated = (accumulated, len(values))
            result.append((key, accumulated))
        return result

    def _create_plot(self, categroies):
        """Draw and save the chart for ``[(domain, (value, total)), ...]``."""
        heights = [value for _key, (value, _total) in categroies]
        totals = [total for _key, (_value, total) in categroies]
        algs = [
            alg if self.fmt_alg is None else self.fmt_alg(alg)
            for alg, _pair in categroies
        ]

        print(heights)
        print(totals)
        print(algs)

        rng = range(len(categroies))

        fig = pyplot.figure()
        try:
            pt = pyplot.bar(rng, totals)
            ph = pyplot.bar(rng, heights)

            pyplot.title(self.title)
            pyplot.xticks(rng, algs, rotation='vertical')
            if categroies:
                # The bar containers are empty without data; [0] would fail.
                pyplot.legend((ph[0], pt[0]), (self.label, 'Total'))
            pyplot.tight_layout()
            pyplot.savefig(self.outfile)
        finally:
            # Close the figure so repeated reports do not accumulate open
            # figures in pyplot's global state.
            pyplot.close(fig)

    def write(self):
        """Group, accumulate and plot the runs."""
        cats = self._categorize()
        cats = self._accumulate(cats)
        self._create_plot(cats)

    def _get_attr(self):
        assert len(self.attributes) == 1
        return self.attributes[0]

class SuiteScatterPlotReport(Report):
    """Scatter plot comparing one attribute between two suites (domains).

    Runs are paired by (algorithm, problem). The attribute value of the run
    in ``suites[0]`` is the x coordinate and the value of the run in
    ``suites[1]`` is the y coordinate. Points are grouped by the run property
    named ``category`` (read from the ``suites[0]`` run and formatted with
    ``fmt_cat``), and each group gets its own marker style. ``write`` renders
    ``self.outfile`` via MatPlotLibMod and also writes a pgfplots version
    next to it with a ``.tex`` extension.
    """

    def __init__(
        self,
        suites=None,
        category=None,
        markers="discrete",
        scale='linear',
        relative=False,
        fmt_cat=ID,
        **kwargs,
    ):
        Report.__init__(self, **kwargs)

        self.matplotlib_options = {
            "figure.figsize": [4, 4],
        }

        self.other_sizes = [2, 3, 6]

        self.suites = suites
        # Run property that determines the shape/colour of a point.
        self.category = category
        self.relative = relative
        # How marker styles are assigned to categories:
        #   "discrete": each category gets its own shape/colour combination.
        #   "range": numeric categories are coloured along a colormap.
        self.markers = markers
        self.scale = scale  # 'linear', 'log', 'symlog'
        self.xscale = self.scale
        # Relative plots show y / x, which is drawn on a log axis.
        self.yscale = "log" if self.relative else self.xscale
        self.xlabel, self.ylabel = self._get_labels(suites)
        self.attribute = self._get_attr()
        self.title = self.attribute or ""
        self.plot_horizontal_line = False
        self.plot_diagonal_line = True
        self.show_missing = True
        self.fmt = fmt_cat

        self.writer = MatPlotLibMod

    def _get_labels(self, suites):
        """Return the (x, y) axis labels for the two suites."""
        assert not suites is None and len(suites) == 2
        return (fmt_domain(suites[0]), fmt_domain(suites[1]))

    def _get_attr(self):
        assert len(self.attributes) == 1
        return self.attributes[0]

    def _turn_into_relative_coords(self, categories):
        """Convert (x, y) points into (x, y / x) for a relative plot.

        Non-positive values are first replaced by 1e-5, because the y axis is
        logarithmic and x is used as a divisor. Sets self.x_upper and
        self.y_upper to the placeholder values used for missing coordinates.
        Returns a new ``{category: coords}`` dict without empty categories.
        """
        assert self.relative

        def positive(value):
            return value if value is None or value > 0 else 1e-5

        clamped = {
            category: [(positive(x), positive(y)) for x, y in coords]
            for category, coords in categories.items()
        }

        y_rel_max = 0
        for coords in clamped.values():
            for x, y in coords:
                if x is not None and y is not None:
                    y_rel_max = max(y_rel_max, y / float(x))
        y_rel_missing = y_rel_max * 1.5 if y_rel_max != 0 else None
        x_missing = self._compute_missing_value(clamped, 0, self.xscale)
        self.x_upper = x_missing
        self.y_upper = y_rel_missing

        new_categories = {}
        for category, coords in clamped.items():
            new_coords = []
            for x, y in coords:
                if x is None and y is None:
                    new_coords.append((x_missing, y_rel_missing))
                elif x is None:
                    new_coords.append((x_missing, 1))
                elif y is None:
                    new_coords.append((x, y_rel_missing))
                else:
                    new_coords.append((x, y / float(x)))
            if new_coords:
                new_categories[category] = new_coords
        return new_categories

    def _compute_missing_value(self, categories, axis, scale):
        """Return the coordinate at which missing values on ``axis`` are drawn.

        Returns None if missing values are hidden or none are missing, and 1
        if every value is missing. Otherwise it returns 1.1 x the maximum on
        a linear scale, or the smallest power of ten that is >= the maximum.
        """
        import math

        if not self.show_missing:
            return None
        values = [coord[axis] for coords in categories.values()
                  for coord in coords]
        real_values = [value for value in values if value is not None]
        if len(real_values) == len(values):
            # The list doesn't contain None values.
            return None
        if not real_values:
            return 1
        max_value = max(real_values)
        if scale == "linear":
            return max_value * 1.1
        # math.ceil returns an int, so non-negative exponents already give an
        # int; an int() cast would truncate 0.1, 0.01, ... to 0.
        return 10 ** math.ceil(math.log10(max_value))

    def _group_runs(self):
        """Group runs by (algorithm, problem) -> {domain: run}, so that the
        run of suites[0] can go on the x axis and the run of suites[1] on
        the y axis."""
        groups = {}
        for run in self.props.values():
            groups.setdefault((run["algorithm"], run["problem"]), {})[run['domain']] = run
        return groups

    def _extract_attrs(self, groups):
        """Return ``{category: [(x, y), ...]}`` for the plotted attribute.

        Groups lacking a run of either suite are skipped. Pairs with a
        missing value are dropped, but their category still gets a (possibly
        empty) entry.
        """
        attr = self.attribute
        x_suite, y_suite = self.suites
        per_category = {}
        for runs in groups.values():
            if x_suite not in runs or y_suite not in runs:
                continue
            x = runs[x_suite].get(attr, None)
            y = runs[y_suite].get(attr, None)
            category = self.fmt(runs[x_suite].get(self.category, None))
            points = per_category.setdefault(category, [])
            if x is not None and y is not None:
                points.append((x, y))
        return per_category

    def _get_category_styles_discrete(self, categories):
        """Assign each (sorted) category a marker shape/colour combination."""
        shapes = "x+os^v<>D"
        colors = [f"C{c}" for c in range(10)]

        num_styles = len(shapes) * len(colors)
        styles = [
            {"marker": shape, "c": color}
            for shape, color in itertools.islice(
                zip(itertools.cycle(shapes), itertools.cycle(colors)), num_styles
            )
        ]

        category_styles = {}
        for i, category in enumerate(sorted(categories)):
            category_styles[category] = styles[i % len(styles)]

        return category_styles

    def _get_category_styles_range(self, categories):
        """Colour numeric categories along the 'cool' colormap.

        Categories are normalised to [0, 1] via (c - min) / (max - min), so
        the lowest value gets the first colour and the highest the last.
        """
        if not categories:
            return {}
        cats = list(categories.keys())
        lowest = min(cats)
        highest = max(cats)
        span = highest - lowest
        # pyplot.get_cmap is available in old and new matplotlib releases
        # (matplotlib.cm.get_cmap was removed in 3.9).
        colormap = pyplot.get_cmap('cool', len(cats))

        styles = {}
        for category in cats:
            position = (category - lowest) / span if span else 0.0
            styles[category] = {"marker": "x", "color": colormap(position)}
        return styles

    def _get_bounds(self, categories):
        """Return (x_upper, y_upper, x_lower, y_lower).

        Both upper bounds are 5% above the largest coordinate on either axis
        (and never below 0); lower bounds are left as None.
        """
        largest = 0
        for coords in categories.values():
            for coord in coords:
                for value in coord[:2]:
                    if value is not None and value > largest:
                        largest = value
        upper = largest + 0.05 * largest
        return (upper, upper, None, None)

    def has_multiple_categories(self):
        return any(key is not None for key in self.categories.keys())

    def write(self):
        """Write the matplotlib plot to self.outfile and a pgfplots version
        to the same path with a ``.tex`` extension."""
        import os

        groups = self._group_runs()
        self.categories = self._extract_attrs(groups)
        if self.markers == 'discrete':
            self.styles = self._get_category_styles_discrete(self.categories)
        elif self.markers == 'range':
            self.styles = self._get_category_styles_range(self.categories)
        else:
            raise ValueError(f"Unknown markers mode: {self.markers!r}")
        self.x_upper, self.y_upper, self.x_lower, self.y_lower = self._get_bounds(
            self.categories)

        if self.relative:
            self.plot_diagonal_line = False
            self.plot_horizontal_line = True
            self.categories = self._turn_into_relative_coords(self.categories)

        self.writer.write(self, self.outfile)

        def _get_axis_options(cls, report):
            axis = {}
            axis["xlabel"] = report.xlabel
            axis["ylabel"] = report.ylabel
            axis["title"] = report.title
            axis["legend cell align"] = "left"

            convert_scale = {"log": "log", "symlog": "log", "linear": "normal"}
            axis["xmode"] = convert_scale[report.xscale]
            axis["ymode"] = convert_scale[report.yscale]

            # Plot size is set in inches.
            figsize = report.matplotlib_options.get("figure.figsize")
            if figsize:
                width, height = figsize
                axis["width"] = f"{width:.2f}in"
                axis["height"] = f"{height:.2f}in"

            if report.has_multiple_categories():
                axis["legend style"] = cls._format_options(
                    {"legend pos": "south west"}
                )

            return axis

        # Replace only the extension; str.replace('png', 'tex') would also
        # rewrite directory names and leave a non-.png outfile unchanged
        # (overwriting it).
        tex_outfile = os.path.splitext(self.outfile)[0] + ".tex"

        tmpsize = self.matplotlib_options["figure.figsize"]
        self.matplotlib_options["figure.figsize"] = [3.5, 3.5]

        # Swap in the custom axis options only for this write and restore
        # the class afterwards, so other users of ScatterPgfplots are not
        # affected.
        had_own = "_get_axis_options" in vars(ScatterPgfplots)
        saved = vars(ScatterPgfplots).get("_get_axis_options")
        ScatterPgfplots._get_axis_options = _get_axis_options.__get__(
            ScatterPgfplots, ScatterPgfplots)
        try:
            ScatterPgfplots.write(self, tex_outfile)
        finally:
            if had_own:
                ScatterPgfplots._get_axis_options = saved
            else:
                del ScatterPgfplots._get_axis_options
            self.matplotlib_options["figure.figsize"] = tmpsize


class MatPlotLibMod(ScatterMatplotlib):
    """ScatterMatplotlib variant that draws every category with the style in
    ``report.styles`` and applies the report's precomputed axis bounds."""

    @classmethod
    def _plot(cls, report, axes):
        # The grid switch is passed positionally: its keyword was ``b`` in
        # older matplotlib and is ``visible`` in newer releases.
        axes.grid(True, linestyle="-", color="0.75")

        for category, coords in sorted(report.categories.items()):
            if not coords:
                # A category whose points were all dropped has nothing to
                # draw, and zip(*[]) cannot be unpacked into x and y.
                continue
            x_vals, y_vals = zip(*coords)
            axes.scatter(
                x_vals, y_vals, clip_on=False, label=category, **report.styles[category]
            )
        axes.set_xbound(lower=report.x_lower, upper=report.x_upper)
        axes.set_ybound(lower=report.y_lower, upper=report.y_upper)
