#! /usr/bin/env python

"""Solve some tasks with A* and the LM-Cut heuristic."""

import os
import os.path
import platform
import re
import suites

from lab.environments import BaselSlurmEnvironment

from downward.experiment import FastDownwardExperiment
from downward.reports.absolute import AbsoluteReport
from downward.reports.scatter import ScatterPlotReport




# Benchmark suite to run, as returned by the local `suites` module.
SUITE = suites.suite_optimal_strips()
# Environment object handed to the experiment below.
ENV = BaselSlurmEnvironment()
# Repository passed to every exp.add_algorithm() call in this script.
REPO = "/infai/oldphi00/master_code"
# Raises KeyError right here if DOWNWARD_BENCHMARKS is not set.
BENCHMARKS_DIR = os.environ["DOWNWARD_BENCHMARKS"]

exp = FastDownwardExperiment(environment=ENV)
exp.add_suite(BENCHMARKS_DIR, SUITE)

# Value passed after "--search-time-limit" for every algorithm below.
time_limit = "30m"

# Every configuration runs A* with unit costs on the release64 build and only
# differs in its nickname and heuristic specification, so they are listed in a
# table and registered in one loop (registration order as listed).
for algo_nick, heuristic_spec in (
    # Without tv
    ('1-ianottv',
     'time(ia=true,tv=false,rf=false,ip=false,rd=0,transform=adapt_costs(cost_type=ONE))'),
    # IP
    ('1-iaip',
     'time(ia=true,tv=true,rf=false,ip=true,rd=0,transform=adapt_costs(cost_type=ONE))'),
    # TV
    ('1-iatv',
     'time(ia=true,tv=true,rf=false,ip=false,rd=0,transform=adapt_costs(cost_type=ONE))'),
    # RD can't be combined I think
    ('1-iard',
     'time(ia=true,tv=false,rf=false,ip=false,rd=1,transform=adapt_costs(cost_type=ONE))'),
    # RF
    ('1-iarf',
     'time(ia=true,tv=true,rf=true,ip=false,rd=0,transform=adapt_costs(cost_type=ONE))'),
    # WITH LI: TV
    ('li-iatv',
     'time(ia=true,tv=true,rf=false,ip=false,rd=0,li=true,transform=adapt_costs(cost_type=ONE))'),
    # LM-Cut
    ('2-lmcut',
     'lmcut(transform=adapt_costs(cost_type=ONE))'),
    # State equation heuristic
    ('2-stateequ',
     'operatorcounting([state_equation_constraints()],transform=adapt_costs(cost_type=ONE))'),
):
    # Fresh option lists are built for every call, exactly as when each
    # configuration was spelled out separately.
    exp.add_algorithm(
        algo_nick,
        REPO,
        'default',
        ['--search', 'astar({},cost_type=ONE)'.format(heuristic_spec)],
        driver_options=["--search-time-limit", time_limit, '--build', 'release64'],
        build_options=["release64"])

# Some timesteps with repetition and IP
# for step in timesteps_list:
#    exp.add_algorithm(
#        'timerepip{}gl'.format(step), REPO, 'default', ['--search', 'astar(time(wr=true, ts={}, #gl=true, ip=true))'.format(step)], driver_options=["--search-time-limit", time_limit, '--build', #'release64'], build_options=["release64"])
#    exp.add_algorithm(
#        'timerepip{}rf'.format(step), REPO, 'default', ['--search', 'astar(time(wr=true, ts={}, #rf=true, ip=true))'.format(step)], driver_options=["--search-time-limit", time_limit, '--build', #'release64'], build_options=["release64"])


# Standard report: an AbsoluteReport with its default settings.
exp.add_report(AbsoluteReport())

# Second AbsoluteReport restricted to an explicit attribute list and written
# to myreport.html.
exp.add_report(
    AbsoluteReport(
        attributes=[
            "Unexplained Errors",
            "Info",
            "Summary",
            "cost",
            "coverage",
            "evaluations_until_last_jump",
            "expansions_until_last_jump",
            "generated_until_last_jump",
            "initial_h_value",
            "memory",
            "reopened_until_last_jump",
            "search_time",
            "error",
        ]),
    outfile='myreport.html')

# Plan costs reported by the LM-Cut configuration, keyed by
# run["domain"] + run["id"][2] (presumably domain name + problem name).
lmcut_costs = dict()


def init_lmcut_costs(run):
    """Report filter that records the plan cost of LM-Cut runs.

    If ``run`` belongs to the LM-Cut configuration and has a "cost" entry,
    the cost is stored in the module-level ``lmcut_costs`` table so that the
    sanity checks can compare other runs on the same task against it.

    The run is always returned unchanged, so this filter never drops a run.
    """
    # The LM-Cut configuration is registered as "2-lmcut" in this file, so an
    # exact comparison with "lmcut" would never match and the table would stay
    # empty. Accept every algorithm name that ends in "lmcut" instead.
    if run["algorithm"].endswith("lmcut") and "cost" in run:
        lmcut_costs[run["domain"] + run["id"][2]] = run["cost"]
    return run

def sanity_check_1(run):
    """Sanity check (1): initial h value must not exceed the LM-Cut plan cost.

    Writes run["(1)"] = 0 if the premise is fulfilled and 1 otherwise, so that
    a positive sum over all runs means there is something wrong.

    Runs the check cannot be applied to (no initial h value, or no LM-Cut cost
    recorded for the task) get 0 unless a value is already present: the other
    runs of the same problem should still be considered, and the absolute
    report only has entries where all algorithms of a problem have values.
    """
    if "initial_h_value" in run:
        task_key = run["domain"] + run["id"][2]
        if task_key in lmcut_costs:
            within_bound = run["initial_h_value"] <= lmcut_costs[task_key]
            run["(1)"] = 0 if within_bound else 1
    if "(1)" not in run:
        run["(1)"] = 0
    return run

def debug(run):
    """Pass-through report filter kept as a hook for ad-hoc inspection.

    Returns ``run`` untouched; add temporary print statements here when the
    contents of the runs need to be looked at while a report is built.
    """
    return run

# Initial heuristic values, keyed by
# run["algorithm"] + run["domain"] + run["id"][2].
initial_h_values = dict()


def init_initial_h_values(run):
    """Report filter that records the initial h value of a run.

    Every run that has an "initial_h_value" entry is stored in the
    module-level ``initial_h_values`` table. The run itself is returned
    unchanged, so this filter never drops a run.
    """
    # Values used to be stored only for algorithms whose name contains "time".
    # sanity_check_4 looks up a "stateequ..." key in this table, which could
    # therefore never be present; store the value for every algorithm so that
    # baseline lookups can succeed.
    if "initial_h_value" in run:
        key = run["algorithm"] + run["domain"] + run["id"][2]
        initial_h_values[key] = run["initial_h_value"]
    return run

def sanity_check_2(run):
    """Sanity check (2): more time steps must not lower the initial h value.

    Applies to runs whose algorithm name contains "timerep" and ends in the
    number of time steps. For every entry of ``timesteps_list`` that is
    larger than the run's own step count, the recorded initial h value of
    "timerep<step>" on the same task is compared with this run's value.

    Writes run["(2)"] = 0 if no counterexample is found and 1 as soon as a
    configuration with more time steps has a smaller initial h value. Runs
    the check cannot be applied to get 0 unless a value is already present.

    NOTE: ``timesteps_list`` is not defined in the visible part of this file;
    it has to exist at module level before an applicable run reaches this
    filter.
    """
    if "timerep" in run["algorithm"] and "initial_h_value" in run:
        steps_match = re.search(r'(\d+)$', run["algorithm"])
        # Names without a trailing step count cannot be checked.
        if steps_match is not None:
            # Compare step counts numerically; as strings "10" sorts before "5".
            own_steps = int(steps_match.group(1))
            task_key = run["domain"] + run["id"][2]
            check = 0  # Stays 0 unless a counterexample is found below.
            for step in timesteps_list:
                if int(step) <= own_steps:
                    continue
                other_h = initial_h_values.get("timerep" + str(step) + task_key)
                # Configurations without a recorded value are skipped.
                if other_h is not None and other_h < run["initial_h_value"]:
                    check = 1
                    break
            run["(2)"] = check
    if "(2)" not in run:
        run["(2)"] = 0
    return run

def sanity_check_3(run):
    """Sanity check (3): "timeitip" runs must hit the LM-Cut plan cost exactly.

    For runs whose algorithm name contains "timeitip", writes run["(3)"] = 0
    if the initial h value equals the cost recorded in ``lmcut_costs`` for
    the same task and 1 otherwise. Runs the check cannot be applied to get 0
    unless a value is already present.
    """
    if "timeitip" in run["algorithm"] and "initial_h_value" in run:
        task_key = run["domain"] + run["id"][2]
        if task_key in lmcut_costs:
            is_perfect = run["initial_h_value"] == lmcut_costs[task_key]
            run["(3)"] = 0 if is_perfect else 1
    if "(3)" not in run:
        run["(3)"] = 0
    return run

def sanity_check_4(run):
    """Sanity check (4): "time" runs must not fall below the "stateequ" value.

    For runs whose algorithm name contains "time", writes run["(4)"] = 0 if
    the initial h value is at least the value stored in ``initial_h_values``
    under "stateequ" + domain + problem, and 1 otherwise. Runs the check
    cannot be applied to get 0 unless a value is already present.
    """
    if "time" in run["algorithm"] and "initial_h_value" in run:
        baseline_key = "stateequ" + run["domain"] + run["id"][2]
        if baseline_key in initial_h_values:
            at_least_baseline = (
                run["initial_h_value"] >= initial_h_values[baseline_key])
            run["(4)"] = 0 if at_least_baseline else 1
    if "(4)" not in run:
        run["(4)"] = 0
    return run

def sanity_check_5(run):
    """Sanity check (5): an "ip" variant must not fall below its non-"ip" twin.

    Applies to runs whose algorithm name contains both "time" and "ip". The
    twin is looked up in ``initial_h_values`` under the algorithm name with
    every "ip" removed. Writes run["(5)"] = 0 if this run's initial h value is
    at least the twin's value and 1 otherwise. Runs the check cannot be
    applied to get 0 unless a value is already present.
    """
    algo = run["algorithm"]
    if "time" in algo and "ip" in algo and "initial_h_value" in run:
        twin_key = algo.replace("ip", "") + run["domain"] + run["id"][2]
        if twin_key in initial_h_values:
            at_least_twin = run["initial_h_value"] >= initial_h_values[twin_key]
            run["(5)"] = 0 if at_least_twin else 1
    if "(5)" not in run:
        run["(5)"] = 0
    return run

# Check that the heuristic value of the initial state is smaller than or equal to the optimal plan cost. (1)
# Check that the heuristic value of the initial state is bigger than or equal to the initial-state heuristic values of heuristics with fewer time steps. (2)
# Check that the iterative heuristic with IP is perfect (by comparing its initial h value with the cost of a perfect plan). (3)
# Check that the initial h values of the time heuristics are at least as good as those of the state equation heuristic. (4)
# Check that the IP versions have initial h values at least as good as those of the versions without IP. (5)
# Sanity check report: the filters run in the listed order; the init_*
# filters fill the lookup tables that the sanity_check_* filters read.
exp.add_report(
    AbsoluteReport(
        attributes=["(1)", "(2)", "(3)", "(4)", "(5)", "Summary"],
        filter=[
            debug,
            init_lmcut_costs,
            sanity_check_1,
            init_initial_h_values,
            sanity_check_2,
            sanity_check_3,
            sanity_check_4,
            sanity_check_5,
        ]),
    outfile='sanitychecks.html')

# Scatter plot of the number of expansions for the two listed algorithms.
exp.add_report(
    ScatterPlotReport(
        attributes=["expansions"],
        filter_algorithm=["timeit", "timerep0"]),
    outfile='scatterplot.png')

def add_initialoroptimalattribute(run):
    """Report filter that adds the "initialoroptimal" attribute to a run.

    For the LM-Cut configuration the attribute is the run's "cost"; for every
    other algorithm it is the run's "initial_h_value". The attribute is None
    when the source entry is missing. The (modified) run is returned.
    """
    # The LM-Cut configuration is registered as "2-lmcut" in this file, so
    # match on the name suffix rather than on the exact string "lmcut".
    if run["algorithm"].endswith("lmcut"):
        run["initialoroptimal"] = run.get("cost")
    else:
        # This branch was indented with a tab among space-indented lines,
        # which Python 3 rejects with a TabError; spaces are used throughout.
        run["initialoroptimal"] = run.get("initial_h_value")
    return run

def removelargehvalues(run, limit=100):
    """Report filter that keeps only runs with a small "initialoroptimal".

    Returns True if run["initialoroptimal"] is below ``limit`` (default 100)
    and False otherwise. Runs where the attribute is missing or None are
    treated alike and rejected.
    """
    value = run.get("initialoroptimal")
    # The attribute can be present with the value None, which bypasses a
    # .get() default and makes `None < limit` raise TypeError on Python 3.
    if value is None:
        return False
    return value < limit

# NOTE: the "expansions" scatter plot for "timeit" vs. "timerep0" that writes
# scatterplot.png is already registered earlier in this file. It is not added
# a second time here, because the duplicate would only target the same output
# file again.

# Scatter plot of the "initialoroptimal" attribute: the first filter adds the
# attribute to every run, the second one drops runs whose value is too large.
exp.add_report(
    ScatterPlotReport(
        attributes=["initialoroptimal"],
        filter_algorithm=["lmcut", "timerep5"],
        filter=[add_initialoroptimalattribute, removelargehvalues]),
    outfile='initialhvalues.png')

# Debugging helper: dump a run to stdout.
def test(run):
    """Print the algorithm name of ``run`` followed by all of its entries.

    Returns None; it only produces output.
    """
    print(run["algorithm"])
    for key in run:
        print(key, run[key])

def _is_lmcut_run(run):
    """Keep only runs whose algorithm is named exactly "lmcut"."""
    return run["algorithm"] == "lmcut"


def _algorithm_of(run):
    """Use the algorithm name as the category of a run."""
    return run["algorithm"]


# Scatter plot of expansions restricted to the runs accepted by the filter
# and categorised by algorithm name.
exp.add_report(
    ScatterPlotReport(
        attributes=["expansions"],
        filter=_is_lmcut_run,
        get_category=_algorithm_of),
    outfile='scatterplot_sanity_domination.png')

# Parse the commandline and show or run experiment steps.
exp.run_steps()
