#! /usr/bin/env python

"""Solve some tasks with A* and the LM-Cut heuristic."""
from __future__ import division
import os
import os.path
import platform
import re
import suites

from lab.environments import BaselSlurmEnvironment

from downward.experiment import FastDownwardExperiment
from downward.reports.absolute import AbsoluteReport
from downward.reports.scatter import ScatterPlotReport
from downward.reports.compare import ComparativeReport
from downward.reports.taskwise import TaskwiseReport




SUITE = suites.suite_optimal_strips()
ENV = BaselSlurmEnvironment()
REPO = "/infai/oldphi00/master_code"
BENCHMARKS_DIR = os.environ["DOWNWARD_BENCHMARKS"]

exp = FastDownwardExperiment(environment=ENV)
exp.add_suite(BENCHMARKS_DIR, SUITE)

# Search time limit handed to the Fast Downward driver for every algorithm.
time_limit = "30m"


def _add_release_algorithm(name, search):
    """Register algorithm *name* running ``--search <search>``.

    Every configuration of this experiment uses the 'default' revision of
    REPO, the shared search time limit and the 'release64' build.
    """
    exp.add_algorithm(
        name, REPO, 'default', ['--search', search],
        driver_options=["--search-time-limit", time_limit,
                        '--build', 'release64'],
        build_options=["release64"])


# Time-step ("timerep") configurations: for each number of time steps, one
# algorithm per (name suffix, search template) pair, registered in this order.
# NOTE(review): later reports also plot "timerep10sync" and "timerep15sync",
# which this list (0, 2, 4, 6) never registers -- confirm whether 10 and 15
# are meant to be added here.
_TIMEREP_VARIANTS = [
    ('sync', 'astar(time(wr=true, ts={}, tv=true))'),
    ('unsync', 'astar(time(wr=true, ts={}, st=0, tv=true))'),
    ('syncip', 'astar(time(wr=true, ts={}, tv=true, ip=true))'),
    ('syncnottv', 'astar(time(wr=true, ts={},tv=false))'),
]
timesteps_list = [0, 2, 4, 6]
for step in timesteps_list:
    for suffix, search_template in _TIMEREP_VARIANTS:
        _add_release_algorithm('timerep{}{}'.format(step, suffix),
                               search_template.format(step))

# LM-Cut
_add_release_algorithm('lmcut', 'astar(lmcut())')

# State equation heuristic
_add_release_algorithm(
    'stateequ', 'astar(operatorcounting([state_equation_constraints()]))')

# Plan cost of the "lmcut" algorithm per task. The key is the domain name
# concatenated with the problem name (run["id"][2]). Filled as a side effect
# of the init_lmcut_costs report filter.
lmcut_costs = dict()

def init_lmcut_costs(run):
    """Remember the plan cost of an "lmcut" run; always keep the run.

    Runs of other algorithms, or "lmcut" runs without a "cost" attribute,
    pass through without touching lmcut_costs.
    """
    is_lmcut_run = run["algorithm"] == "lmcut"
    if is_lmcut_run and "cost" in run:
        task_key = run["domain"] + run["id"][2]
        lmcut_costs[task_key] = run["cost"]
    return run


def debug(run):
    """Identity report filter: hand *run* back unchanged.

    Placeholder hook for inspecting runs while developing reports; it
    currently has no side effects.
    """
    unchanged = run
    return unchanged

# Initial h value per time-step run. The key is the algorithm name, domain
# name and problem name (run["id"][2]) concatenated.
initial_h_values = dict()

def init_initial_h_values(run):
    """Record the "initial_h_value" of runs whose algorithm name contains "time".

    Always returns the run itself. Runs of other algorithms, or runs
    without an "initial_h_value", are left unrecorded.
    """
    algorithm = run["algorithm"]
    if "time" not in algorithm or "initial_h_value" not in run:
        return run
    task_key = algorithm + run["domain"] + run["id"][2]
    initial_h_values[task_key] = run["initial_h_value"]
    return run


# Create reports that are actually used in the thesis
def removelargehvalues(run):
    """Return True only if the run's "initialoroptimal" value is below 100.

    A missing attribute counts as 100, so such runs are rejected.
    """
    limit = 100
    value = run.get("initialoroptimal", limit)
    return value < limit

def removelargehvalues2(run):
    """Return True only if the run has an "initial_h_value" of at most 50."""
    if "initial_h_value" not in run:
        return False
    return run["initial_h_value"] <= 50

def lowervalues(run):
    """Cap the run's "initial_h_value" at 50 in place and return the run.

    Runs without an "initial_h_value" are returned unchanged.
    """
    cap = 50
    if "initial_h_value" in run and run["initial_h_value"] > cap:
        run["initial_h_value"] = cap
    return run

def getPercentage(run):
    """Attach initial_h_value / LM-Cut plan cost as run["percentage"].

    The LM-Cut cost is looked up in the module-level lmcut_costs dict under
    domain name + problem name (run["id"][2]). The division is true division;
    the file imports ``division`` from __future__.

    Returns the run with "percentage" added when a non-zero cost is known and
    the run has an "initial_h_value". Otherwise returns False, which as a lab
    report filter presumably drops the run. A zero or None cost is treated as
    unknown instead of crashing the report with a division error.
    """
    if "initial_h_value" not in run:
        return False
    cost = lmcut_costs.get(run["domain"] + run["id"][2])
    if not cost:
        # Unknown task, or a cost of 0/None that cannot be divided by.
        return False
    run["percentage"] = run["initial_h_value"] / cost
    return run

def getPercentage100(run):
    """Attach 100 * initial_h_value / LM-Cut plan cost as run["percentage"].

    The LM-Cut cost is looked up in the module-level lmcut_costs dict under
    domain name + problem name (run["id"][2]). The value is NOT rounded.

    The run is always returned. "percentage" is only added when the run has
    an "initial_h_value" and a non-zero cost is known for its task. A zero
    or None cost is skipped rather than raising a division error.
    """
    cost = lmcut_costs.get(run["domain"] + run["id"][2])
    if cost and "initial_h_value" in run:
        run["percentage"] = (100.0 * run["initial_h_value"]) / cost
    return run

def only_optimal_vs_competitors(run):
    """Keep the state equation, LM-Cut and synchronized timerep runs.

    As a side effect, strips "-strips" from the run's domain name in place.
    """
    shown_algorithms = (
        'stateequ',
        'lmcut',
        'timerep0sync',
        'timerep2sync',
        'timerep4sync',
        'timerep6sync',
    )
    run['domain'] = run['domain'].replace("-strips", "")
    return run['algorithm'] in shown_algorithms

def only_optimal_vs_competitors_without_parcprinter(run):
    """Keep the competitor runs as in the coverage table, minus parcprinter.

    Strips "-strips" from the run's domain name in place. Returns False for
    any domain whose (stripped) name contains "parcprinter". Otherwise
    returns whether the algorithm is the state equation, LM-Cut or a
    synchronized timerep configuration.
    """
    domain = run['domain'].replace("-strips", "")
    run['domain'] = domain
    if 'parcprinter' in domain:
        return False
    return run['algorithm'] in (
        'stateequ',
        'lmcut',
        'timerep0sync',
        'timerep2sync',
        'timerep4sync',
        'timerep6sync',
    )

def domain_as_category(run1, run2):
    """Return True (and print the task) when the runs' percentages are ordered unexpectedly.

    If run1 is a "stateequ" run, this flags run1["percentage"] being larger
    than run2["percentage"]. For any other algorithm, it flags run1's
    percentage being smaller. For every flagged pair, "fail in", the domain
    and the problem name (run1["id"][2]) are printed on separate lines.

    Its (run1, run2) signature suggests use as a scatter-plot category
    callback. It is not referenced in the visible code -- confirm before
    removing.
    """
    if run1["algorithm"] == 'stateequ':
        violated = run1["percentage"] > run2["percentage"]
    else:
        violated = run1["percentage"] < run2["percentage"]
    if violated:
        # Single-argument print(...) prints the same text under Python 2 and
        # is also valid Python 3 syntax, unlike the bare print statement.
        print("fail in")
        print(run1['domain'])
        print(run1["id"][2])
    return violated


# Initial h value and coverage tables.
exp.add_report(
    AbsoluteReport(
        attributes=["Summary", "coverage"],
        filter=only_optimal_vs_competitors,
        format='tex'),
    outfile='final_coverage_static.tex')

exp.add_report(
    AbsoluteReport(
        attributes=["Summary", "initial_h_value"],
        filter=only_optimal_vs_competitors_without_parcprinter,
        format='tex'),
    outfile='final_initialhvalue_static.tex')

# Scatter plots of the "percentage" attribute, added in this order. Each
# entry is (plotted algorithms, percentage filter, output file). Each plot
# runs init_lmcut_costs before the percentage filter.
# NOTE(review): this presumably relies on lab passing the "lmcut" runs
# through the custom filters (before filter_algorithm removes them) ahead of
# the plotted runs of the same task -- verify.
# NOTE(review): "timerep10sync" and "timerep15sync" are not registered by
# the timesteps_list loop (0, 2, 4, 6), so those two plots may be empty --
# confirm.
_PERCENTAGE_SCATTER_PLOTS = [
    # ATUR 2 and ATUR 6 vs. the state equation heuristic (h / cost ratio).
    (["timerep2sync", "stateequ"], getPercentage,
     'final_scatter_ATUR2_perc_6.tex'),
    (["timerep6sync", "stateequ"], getPercentage,
     'final_scatter_ATUR6_perc.tex'),
    # ATUR 10 and ATUR 15 vs. the next smaller number of time steps (percent).
    (["timerep10sync", "timerep6sync"], getPercentage100,
     'final_scatter_ATUR10_perc.tex'),
    (["timerep15sync", "timerep10sync"], getPercentage100,
     'final_scatter_ATUR15_perc.tex'),
    # "ip=true" variants vs. the same configuration without it (percent).
    (["timerep0syncip", "timerep0sync"], getPercentage100,
     'final_scatter_ATUR0IP_perc.tex'),
    (["timerep2syncip", "timerep2sync"], getPercentage100,
     'final_scatter_ATUR2IP_perc.tex'),
    (["timerep6syncip", "timerep6sync"], getPercentage100,
     'final_scatter_ATUR6IP_perc.tex'),
]
for plotted, percentage_filter, outfile in _PERCENTAGE_SCATTER_PLOTS:
    exp.add_report(
        ScatterPlotReport(
            attributes=["percentage"],
            filter_algorithm=plotted,
            filter=[init_lmcut_costs, percentage_filter],
            format="tex"),
        outfile=outfile)




# Debugging
def only_stateequ_vs_ATUR2(run):
    """Keep only "stateequ" and "timerep2sync" runs.

    Strips "-strips" from the run's domain name in place.
    """
    run['domain'] = run['domain'].replace("-strips", "")
    algorithm = run['algorithm']
    return algorithm in ('stateequ', 'timerep2sync')

# Side-by-side comparison of the percentage attribute: ATUR 2 vs. the state
# equation heuristic. init_lmcut_costs runs first so that getPercentage100
# can look up the LM-Cut plan costs.
algorithm_pairs = [('timerep2sync', 'stateequ')]
exp.add_report(
    ComparativeReport(
        algorithm_pairs,
        attributes=['percentage'],
        filter=[init_lmcut_costs, getPercentage100]))

# Initial h value of the "stateequ" algorithm per task. The key is the
# domain name concatenated with the problem name (run["id"][2]).
hseq_values = dict()


def init_hseq_values(run):
    """Record the "initial_h_value" of "stateequ" runs; always keep the run."""
    if run["algorithm"] != "stateequ" or "initial_h_value" not in run:
        return run
    hseq_values[run["domain"] + run["id"][2]] = run["initial_h_value"]
    return run

def get_is_bigger_than(run):
    """Flag timerep6sync runs whose initial h value is below the state equation's.

    For a "timerep6sync" run with an "initial_h_value" whose task has an
    entry in the module-level hseq_values dict, set run["is_bigger_than"] to
    1 if the run's initial h value is strictly smaller than the stored
    "stateequ" value, else 0. The run is always returned.

    NOTE(review): the attribute is 1 when the *state equation* value is the
    bigger one; the name reads the other way -- confirm intent.
    """
    if run["algorithm"] not in ["timerep6sync"] or "initial_h_value" not in run:
        return run
    task_key = run["domain"] + run["id"][2]
    if task_key in hseq_values:
        smaller = run["initial_h_value"] < hseq_values[task_key]
        run["is_bigger_than"] = int(smaller)
    return run

# Debugging table for "timerep6sync": the "is_bigger_than" attribute computed
# by get_is_bigger_than against the recorded "stateequ" initial h values.
# NOTE(review): init_hseq_values must see the "stateequ" runs even though
# filter_algorithm keeps only "timerep6sync"; presumably lab applies the
# custom filters first and visits "stateequ" runs earlier -- verify.
exp.add_report(
    AbsoluteReport(
        attributes=["is_bigger_than", "Summary"],
        filter=[init_hseq_values, get_is_bigger_than],
        filter_algorithm=["timerep6sync"]),
    outfile='debugging_h_value.html')

# Parse the commandline and show or run experiment steps.
exp.run_steps()
