from collections import namedtuple
import getpass
from pathlib import Path
import platform
import shutil
import subprocess
import sys
import os
import subprocess
import logging
import re

import pyparsing as pp

from reports import SuiteScatterPlotReport, BarChartReport, LatexTable, AggregateReport, MEAN_TIME_REPORT
from downward.experiment import FastDownwardExperiment, FastDownwardRun
from downward.reports.absolute import AbsoluteReport, PlanningReport
from downward.reports.scatter import ScatterPlotReport
from lab import tools
from lab.environments import BaselSlurmEnvironment, LocalEnvironment
from lab.experiment import ARGPARSER, get_default_data_dir
from lab.reports import Report, Attribute, geometric_mean


# Touch names that this module imports but does not use itself, so linters do
# not report them as unused; experiment scripts importing from here use them.
assert BaselSlurmEnvironment and LocalEnvironment and ScatterPlotReport


# Directory that contains this file.
DIR = Path(__file__).resolve().parent
# Host name of the machine running the script.
NODE = platform.node()
# True when the host name ends in one of the cluster domains below
# (presumably the Basel sciCORE grid).
REMOTE = NODE.endswith((".scicore.unibas.ch", ".cluster.bc2.ch"))

def parse_args():
    """Return the command-line arguments parsed by lab's shared ``ARGPARSER``."""
    arguments = ARGPARSER.parse_args()
    return arguments


# Command-line arguments, parsed once when this module is imported.
ARGS = parse_args()

# Custom attribute "evaluations_per_time": aggregated with the geometric mean,
# larger values are better (min_wins=False), shown with one decimal digit.
EVALUATIONS_PER_TIME = Attribute(
    "evaluations_per_time",
    min_wins=False,
    function=geometric_mean,
    digits=1,
)

def add_reports(exp, attributes):
    """Register all tables, aggregate values and plots of the experiment.

    Args:
        exp: experiment the reports are attached to via ``exp.add_report``.
        attributes: not used by the active reports; only the disabled
            ``AbsoluteReport`` at the end of this function would consume it.
            Kept so existing callers keep working.
    """

    def filter_algo(alg_name):
        """Return a run filter keeping algorithms whose name ends in ``-<alg_name>``."""
        # re.escape: match the name literally even if it contains regex
        # metacharacters such as '.' or '+'.
        regex = re.compile(f'-{re.escape(alg_name)}$')

        def f(t):
            return bool(regex.search(t['algorithm']))

        return f

    def categorize_exit_codes(code):
        """Map a Fast Downward exit code to a coarse outcome category.

        Fast Downward groups its exit codes by decade: 0-9 plan found,
        10-19 no solution, 20-29 resource limits (split here into memory
        and time), 30-39 errors.
        """
        if code in range(0, 10):
            return "Success"
        if code in range(10, 20):
            return "No Solution"
        if code in (20, 22, 24):
            return "Out of Memory"
        if code in (21, 23):
            return "Out of Time"
        if code in range(30, 40):
            return "Fail"
        # Anything else (e.g. a missing code) goes into a separate
        # catch-all category; report it so it does not go unnoticed.
        logging.warning("Unexpected planner exit code: %s", code)
        return "fail"

    def pl(t):
        """Keep only runs that produced a non-empty plan."""
        return "plan_length" in t and t["plan_length"] > 0

    def has_expansions(t):
        """Keep only runs that report a positive number of expansions."""
        return "expansions" in t and t["expansions"] > 0

    def fmt_alg(alg):
        """Strip everything up to and including the first '-' of an algorithm name."""
        _, sep, rest = alg.partition('-')
        return rest if sep else alg

    def sum_filter(values):
        """Return an aggregator yielding ``(#items contained in values, #items)``."""
        def f(array):
            return (sum(1 for item in array if item in values), len(array))
        return f

    # Build the algorithm filters once and share them between reports.
    is_lama = filter_algo("lama")
    is_hmax = filter_algo("astar-hmax")

    # ---- Tables -----------------------------------------------------------
    exp.add_report(MEAN_TIME_REPORT(
        attributes=['total_time'],
    ), outfile="tbl-time-mean.tex")

    exp.add_report(LatexTable(
        attributes=['planner_exit_code'],
        categorize=categorize_exit_codes,
    ), outfile="tbl-exit-codes.tex")

    # ---- Single aggregate values (for inclusion in LaTeX text) -------------
    exp.add_report(AggregateReport(
        attributes=["pddl_objects"],
        filter=pl,
        med_func=lambda m: "{}".format(round(m['default'])),
        acc_func=max,
    ), outfile="txt-max-pddl-default.tex")

    exp.add_report(AggregateReport(
        attributes=["expansions"],
        filter=pl,
        med_func=lambda m: "{}".format(round(m['default'])),
    ), outfile="txt-expansions-default-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["expansions"],
        med_func=lambda m: "{}".format(round(m['ordered'])),
        filter=pl,
    ), outfile="txt-expansions-ordered-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["cost"],
        med_func=lambda m: "{:.2f}".format(m['ordered']),
        filter=[pl, is_lama],
    ), outfile="txt-cost-ordered-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["cost"],
        med_func=lambda m: "{:.2f}".format(m['default']),
        filter=pl,
    ), outfile="txt-cost-default-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["cost"],
        med_func=lambda m: "{:.2f}".format(m['no-derived-ordered']),
        filter=[pl, is_lama],
    ), outfile="txt-cost-noderivedordered-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["expansions"],
        med_func=lambda m: "{:.2f}\\%".format(100 * (m['ordered'] / m['default'])),
        filter=[pl, is_lama],
    ), outfile="txt-expansions-ordered-default-percent.tex")

    exp.add_report(AggregateReport(
        attributes=["total_time"],
        med_func=lambda m: "{:.2f}".format(m['default']),
        filter=[pl],
    ), outfile="txt-time-default-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["total_time"],
        med_func=lambda m: "{:.2f}".format(m['no-derived']),
        filter=[pl],
    ), outfile="txt-time-noderived-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["total_time"],
        med_func=lambda m: "{:.2f}".format(m['ordered']),
        filter=[pl],
    ), outfile="txt-time-ordered-mean.tex")

    exp.add_report(AggregateReport(
        attributes=["total_time"],
        med_func=lambda m: "{:.2f}".format(m['no-derived-ordered']),
        filter=[pl],
    ), outfile="txt-time-noderivedordered-mean.tex")

    # ---- Bar chart: solved tasks (exit codes 0-3 found a plan) -------------
    exp.add_report(BarChartReport(
        attributes=["planner_exit_code"],
        acc_func=sum_filter([0, 1, 2, 3]),
        fmt_alg=fmt_alg,
        filter=[],
    ), outfile="bar-solved.png")

    # ---- Scatter plots: total time per algorithm ---------------------------
    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'no-derived'],
        category="algorithm",
        fmt_cat=fmt_alg,
        relative=False,
        filter=pl,
    ), outfile='scatter-time-algo-default-noderived.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'no-derived'],
        category="algorithm",
        fmt_cat=fmt_alg,
        relative=True,
        filter=pl,
    ), outfile='scatter-time-algo-default-noderived-relative.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'no-derived'],
        scale="log",
        category="algorithm",
        fmt_cat=fmt_alg,
        relative=True,
        filter=pl,
    ), outfile='scatter-time-algo-default-noderived-relative-log.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'ordered'],
        category="algorithm",
        fmt_cat=fmt_alg,
        filter=pl,
    ), outfile='scatter-time-algo-default-ordered.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'ordered'],
        relative=True,
        scale="log",
        category="algorithm",
        fmt_cat=fmt_alg,
        filter=pl,
    ), outfile='scatter-time-algo-default-ordered-relative-log.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'no-derived-ordered'],
        relative=True,
        scale="log",
        category="algorithm",
        fmt_cat=fmt_alg,
        filter=pl,
    ), outfile='scatter-time-algo-default-noderivedordered-relative-log.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['ordered', 'no-derived'],
        category="algorithm",
        fmt_cat=fmt_alg,
        filter=pl,
    ), outfile='scatter-time-algo-ordered-noderived.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['no-derived', 'no-derived-ordered'],
        category="algorithm",
        fmt_cat=fmt_alg,
        relative=True,
        scale="log",
        filter=pl,
    ), outfile='scatter-time-algo-noderived-noderivedordered-relative-log.png')

    # Total time, coloured by number of PDDL objects instead of algorithm.
    exp.add_report(SuiteScatterPlotReport(
        attributes=['total_time'],
        suites=['default', 'no-derived-ordered'],
        category="pddl_objects",
        markers="range",
        relative=True,
        scale="log",
        filter=pl,
    ), outfile='scatter-time-pddlobj-default-noderivedordered-relative-log.png')

    # ---- Scatter plots: plan quality (LAMA only) ---------------------------
    exp.add_report(SuiteScatterPlotReport(
        attributes=['plan_length'],
        suites=["default", "no-derived"],
        filter=[pl, is_lama],
    ), outfile='scatter-plan-default-noderived.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['plan_length'],
        suites=["default", "ordered"],
        filter=[pl, is_lama],
    ), outfile='scatter-plan-default-ordered.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['cost'],
        suites=["default", "no-derived"],
        filter=[pl, is_lama],
    ), outfile='scatter-cost-default-noderived.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['cost'],
        suites=["default", "ordered"],
        filter=[pl, is_lama],
    ), outfile='scatter-cost-default-ordered.png')

    # ---- Scatter plots: expansions -----------------------------------------
    exp.add_report(SuiteScatterPlotReport(
        attributes=['expansions'],
        suites=["default", "no-derived"],
        relative=True,
        scale="log",
        filter=[pl, has_expansions],
    ), outfile='scatter-expansions-default-noderived.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['expansions'],
        suites=["default", "ordered"],
        relative=True,
        scale="log",
        filter=[pl, has_expansions],
    ), outfile='scatter-expansions-default-ordered.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['expansions'],
        suites=["default", "no-derived-ordered"],
        relative=True,
        scale="log",
        filter=[pl, has_expansions],
    ), outfile='scatter-expansions-default-noderivedordered.png')

    # ---- Scatter plots: translator statistics (A* h^max only) --------------
    # NOTE: the misspelled output file names ("translaotr", "varibales") are
    # kept on purpose in case documents already include them by that name.
    exp.add_report(SuiteScatterPlotReport(
        attributes=['translator_variables'],
        suites=["default", "no-derived"],
        filter=[pl, is_hmax],
    ), outfile='scatter-translaotr-varibales.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['translator_axioms'],
        suites=["default", "no-derived"],
        relative=True,
        scale="log",
        filter=[pl, is_hmax],
    ), outfile='scatter-translator-axioms.png')

    exp.add_report(SuiteScatterPlotReport(
        attributes=['translator_derived_variables'],
        suites=["default", "no-derived"],
        filter=[pl, is_hmax],
    ), outfile='scatter-derived-varibales.png')

    # Disabled full reports:
    #exp.add_report(AbsoluteReport(attributes=attributes), outfile="report.html")
    #exp.add_report(PlanningReport(attributes=["pddl_objects", "memory"]), outfile="pddl_obj.html")

def copy_plots(exp, path, *, enabled=False):
    """Copy the files generated in ``exp.eval_dir`` into *path*.

    Copying is off by default: the function then only prints a warning and
    returns. With ``enabled=True`` every regular file in ``exp.eval_dir``
    except lab's ``properties`` file is copied.

    Args:
        exp: experiment whose ``eval_dir`` holds the generated reports.
        path: destination directory; created if it does not exist.
        enabled: actually perform the copy.
    """
    if not enabled:
        print("!!!!!!! PLOTS ARE NOT COPIED !!!!!!")
        return

    # Ensure *path* is a directory; otherwise shutil.copy would treat it as
    # a file name and every copy would overwrite the previous one.
    os.makedirs(path, exist_ok=True)

    src = exp.eval_dir
    for name in os.listdir(src):
        source_file = os.path.join(src, name)
        # Skip lab's properties file and sub-directories (shutil.copy cannot
        # copy directories).
        if name == 'properties' or not os.path.isfile(source_file):
            continue
        shutil.copy(source_file, path)
    

def get_git_sha(rev):
    """Resolve the git revision *rev* to its abbreviated commit hash.

    Runs ``git rev-list --abbrev-commit -n 1 <rev>`` in the current working
    directory and returns its output without the trailing newline.
    """
    command = ["git", "rev-list", "--abbrev-commit", "-n", "1", rev]
    raw_output = subprocess.check_output(command)
    text = raw_output.decode("utf-8")
    return text.rstrip("\n")

class CommonExperiment(FastDownwardExperiment):
    """Fast Downward experiment with the shared parsers and suite helpers.

    Compared to ``FastDownwardExperiment`` this class

    * registers lab's standard parsers plus the local ``parser.py``,
    * can add benchmark problems taken from specific git revisions
      (``add_suite_revision``), and
    * stores the number of PDDL objects of each run's problem file in the
      run property ``pddl_objects``.
    """

    def __init__(self, only_reports=False, **kwargs):
        """Set up steps, parsers and bookkeeping.

        Args:
            only_reports: when true, the build/start/fetch steps are not
                added, so the experiment can only (re)generate reports.
            **kwargs: forwarded to ``FastDownwardExperiment``.
        """
        super().__init__(**kwargs)

        if not only_reports:
            self.add_step("build", self.build)
            self.add_step("start", self.start_runs)
            self.add_fetcher(name="fetch")

        self.add_parser(self.EXITCODE_PARSER)
        self.add_parser(self.TRANSLATOR_PARSER)
        self.add_parser(self.SINGLE_SEARCH_PARSER)
        self.add_parser(self.ANYTIME_SEARCH_PARSER)
        self.add_parser(self.PLANNER_PARSER)
        self.add_parser(DIR / "parser.py")

        # Root directory for benchmark checkouts of specific git revisions.
        self.suite_revision_cache = os.path.join(
            get_default_data_dir(), 'revision-cache-suite/'
        )
        # (benchmark_dir, sha) pairs to check out; may contain duplicates.
        self.suite_revs = []

        # Memoized parse results: problem file path -> parsed problem.
        self.parsed_problems = {}
        # Problem file path -> {'id', 'git_sha', 'git_rev'} for revision suites.
        self.problem_props = {}

    def add_suites(self, benchmark_dir, suite):
        """Add problems grouped by suite name.

        Args:
            benchmark_dir: directory containing the benchmark domains.
            suite: iterable of ``(suite_name, problem)`` pairs. Problems that
                share a suite name are passed to ``add_suite`` together, in
                order of first appearance of the suite name.
        """
        suites = {}  # suite name -> list of problems
        for suite_name, problem in suite:
            suites.setdefault(suite_name, []).append(problem)

        for problems in suites.values():
            self.add_suite(benchmark_dir, problems)

    def add_suite_revision(self, benchmark_dir, suite):
        """Add problems from given git revisions of *benchmark_dir*.

        Args:
            benchmark_dir: benchmark directory inside the current git
                repository; it is used both as a path below the revision
                cache and as a git pathspec when checking out.
            suite: iterable of ``(rev, problem)`` pairs. ``rev`` is any git
                revision accepted by ``git rev-list``; ``problem`` uses the
                ``domain:problem`` notation.

        Problems are grouped by resolved commit and added from
        ``<suite_revision_cache>/<sha>/<benchmark_dir>``; their
        ``id``/``git_sha``/``git_rev`` are recorded in ``problem_props``.
        """
        suites = {}  # benchmark path in the cache -> (problems, sha, rev)
        shas = {}    # rev -> sha, so git runs once per distinct revision

        for rev, problem in suite:
            if rev not in shas:
                shas[rev] = get_git_sha(rev)
            sha = shas[rev]
            # Remember dir and sha so the build step can check them out.
            self.suite_revs.append((benchmark_dir, sha))

            # Full benchmark path in the revision cache; problems of the same
            # commit end up in the same suite.
            bm_dir = os.path.join(self.suite_revision_cache, sha, benchmark_dir)
            if bm_dir not in suites:
                suites[bm_dir] = ([], sha, rev)
            suites[bm_dir][0].append(problem)

        # Add suites to lab. Each suite keeps its *own* sha (not the sha of
        # the last problem processed above).
        for bm_dir, (problems, sha, rev) in suites.items():
            # lab requires the directory to exist when the suite is added.
            os.makedirs(bm_dir, exist_ok=True)
            self.add_suite(bm_dir, problems)

            for p in problems:
                pddl_file = os.path.join(bm_dir, *p.split(':'))
                self.problem_props[pddl_file] = {
                    'id': [sha, rev],
                    'git_sha': sha,
                    'git_rev': rev,
                }

    def copy_suite_revisions(self):
        """Check out every recorded benchmark revision into the revision cache.

        For each distinct ``(benchmark_dir, sha)`` pair in ``self.suite_revs``
        the directory ``benchmark_dir`` at commit ``sha`` is written to
        ``<suite_revision_cache>/<sha>/<benchmark_dir>``. Pairs whose target
        directory already exists and is non-empty are skipped.

        Raises:
            subprocess.CalledProcessError: if a ``git checkout`` fails.
        """
        import tempfile

        for bm_dir, sha in set(self.suite_revs):
            path = os.path.join(self.suite_revision_cache, sha)
            path_benchmark = os.path.join(path, bm_dir)

            # Skip if a checkout has already been performed. The directory
            # may not exist yet, so check before listing it.
            if os.path.isdir(path_benchmark) and os.listdir(path_benchmark):
                print("Suite revision already cached: ({}) {}".format(sha, bm_dir))
                continue
            print("Caching suite revision: ({}) {}".format(sha, bm_dir))

            os.makedirs(path, exist_ok=True)
            cmd = ["git", "--work-tree={}".format(path),
                   'checkout', sha, '--', bm_dir]
            # `git checkout <sha> -- <paths>` updates the index as well as the
            # work tree; use a throw-away index so the repository's staging
            # area is left untouched.
            with tempfile.TemporaryDirectory() as tmp:
                env = dict(os.environ, GIT_INDEX_FILE=os.path.join(tmp, "index"))
                subprocess.run(cmd, env=env, check=True)

    def build(self, write_to_disk=True):
        """Build the experiment via ``FastDownwardExperiment.build``."""
        # NOTE: checking out revision suites is currently disabled:
        #self.copy_suite_revisions()

        FastDownwardExperiment.build(self, write_to_disk=write_to_disk)

    def parse_problem_file(self, file):
        """Parse a PDDL problem file into nested lists, memoized per path.

        ``;`` line comments are skipped so that their words are not counted
        as objects and parentheses inside them cannot unbalance the parse.
        """
        if file in self.parsed_problems:
            return self.parsed_problems[file]
        expr = pp.nestedExpr()
        expr.ignore(";" + pp.restOfLine)
        [prob] = expr.parseFile(file)
        self.parsed_problems[file] = prob
        return prob

    def get_objects_count(self, parsed_prob):
        """Return the number of objects declared in a parsed PDDL problem.

        Each typed declaration ``... - type`` adds two non-object tokens
        (``-`` and the type), which are subtracted. If several ``:objects``
        sections exist the last one is used; without one the count is 0.
        """
        obj = []
        for cmd in parsed_prob:
            # PDDL keywords are case-insensitive.
            if (len(cmd) > 0 and isinstance(cmd[0], str)
                    and cmd[0].lower() == ':objects'):
                obj = list(cmd[1:])
        return len(obj) - obj.count('-') * 2

    # overwrite default _add_runs method:
    # https://github.com/aibasel/lab/blob/479289238ef76d61f45cfebc1db42c539f6043c7/downward/experiment.py#L383
    def _add_runs(self):
        """Create one run per (algorithm, task) and attach ``pddl_objects``."""
        for algo in self._algorithms.values():
            for task in self._get_tasks():
                self.add_run(FastDownwardRun(self, algo, task))

        for r in self.runs:
            prob_file = r.task.problem_file
            prob = self.parse_problem_file(prob_file)
            r.set_property("pddl_objects", self.get_objects_count(prob))

            logging.debug("problem_props=%s problem_file=%s",
                          self.problem_props, prob_file)

            # Disabled: copy the revision properties onto each run.
            #props = self.problem_props[prob_file]
            #for key in props:
            #    if key == 'id':
            #        r.properties["id"].extend(props['id'])
            #    else:
            #        r.properties[key] = props[key]
