#!/usr/bin/env python
import argparse
import glob
import os
import shutil
import sys
import tempfile

from graphviz import Digraph, Source
import pyparsing as pp


# When True, render_graph() builds no graphs and only returns the expected
# "<name>-<i>.gv" paths; main() then renders those existing .gv files with
# graphviz.Source, and does not read the input file's modification time.
ONLY_RENDER = True


def parse_inval(text):
    """Extract states and actions from a textual plan-validation trace.

    The trace is scanned line by line:

    * a line containing ``state:`` starts a state; the following lines are
      collected until one containing ``'))'`` is seen (assumed to be the
      state's closing line), and the collected text is parsed as a nested
      s-expression;
    * a line containing ``next happening:`` marks the *next* line as an
      action; that line is parsed as an s-expression and the first element
      of its first group is joined with spaces to form the action label
      (assumes that element is itself a group of tokens -- TODO confirm).

    Args:
        text: the whole trace as one string.

    Returns:
        ``(states, actions)`` where ``states`` holds one nested Python list
        per state and ``actions`` starts with the ``'initial_state'``
        placeholder followed by one label per action found.
    """
    parse_state = None
    state_lines = []
    states = []
    actions = []

    for line in text.split('\n'):
        if parse_state == 'read_state':
            state_lines.append(line)
            if '))' in line:
                parse_state = None
                # Join with whitespace so tokens at line boundaries stay
                # separate; plain concatenation would fuse "(at" + "a)".
                states.append(' '.join(state_lines))
                state_lines = []
        elif parse_state == 'read_action':
            actions.append(line)
            parse_state = None
        elif 'state:' in line:
            parse_state = 'read_state'
        elif 'next happening:' in line:
            parse_state = 'read_action'

    expr = pp.nestedExpr()
    states_parsed = [expr.parseString(state)[0].asList() for state in states]
    actions_parsed = ['initial_state']
    for action in actions:
        parsed = expr.parseString(action)
        actions_parsed.append(" ".join(parsed[0][0]))
    return (states_parsed, actions_parsed)

def parse_pddl(text):
    """Extract the ``:init`` and ``:goal`` sections of a PDDL problem.

    The whole problem is parsed as one nested s-expression. Every top-level
    group whose first token is ``:init`` or ``:goal`` (matched
    case-insensitively, as PDDL keywords are case-insensitive) becomes one
    state, labelled with that keyword as written in the input.

    Args:
        text: the PDDL problem text.

    Returns:
        ``(states, actions)``: ``states`` holds the statements of each
        section, with a leading ``(and ...)`` wrapper unwrapped, and
        ``actions`` holds the matching section keyword for each state.
    """
    parser = pp.nestedExpr()
    problem = parser.parseString(text)[0]
    states = []
    actions = []
    for section in problem:
        # Bare tokens (e.g. "define") and empty groups carry no section keyword.
        if isinstance(section, str) or not section:
            continue
        keyword = section[0]
        if isinstance(keyword, str) and keyword.lower() in (':init', ':goal'):
            states.append(section[1:])
            actions.append(keyword)
    # Unwrap a leading (and ...) conjunction so each state is a flat list of facts.
    for i, state in enumerate(states):
        if not state or isinstance(state[0], str) or not state[0]:
            continue
        head = state[0][0]
        if isinstance(head, str) and head.lower() == 'and':
            states[i] = state[0][1:]
    return (states, actions)

def parse_pred(text):
    """Parse a sequence of parenthesised ``(label item item ...)`` groups.

    Each top-level group contributes one entry: its first element is the
    action label and its remaining elements form the state.

    Args:
        text: one or more s-expressions.

    Returns:
        ``(states, actions)`` as two parallel lists.
    """
    groups = pp.OneOrMore(pp.nestedExpr()).parseString(text)
    actions = [group[0] for group in groups]
    states = [group[1:] for group in groups]
    return (states, actions)

def parse_text(t, text):
    """Dispatch *text* to the parser for input format *t*.

    Args:
        t: one of ``'inval'``, ``'pddl'`` or ``'pred'``.
        text: the raw document to parse.

    Returns:
        The ``(states, actions)`` tuple produced by the selected parser.

    Raises:
        ValueError: if *t* is not a supported format.
    """
    if t == 'inval':
        return parse_inval(text)
    if t == 'pddl':
        return parse_pddl(text)
    if t == 'pred':
        return parse_pred(text)
    # Fail loudly rather than returning None, which callers cannot unpack.
    raise ValueError(
        "unknown input type {!r}; expected 'inval', 'pddl' or 'pred'".format(t))


def get_document(file):
    """Return the full text of *file*, or of piped stdin when *file* is None.

    Args:
        file: an open text file, or None to read from ``sys.stdin``.

    Returns:
        The entire contents as a string.

    Raises:
        SystemExit: with status 1 when *file* is None and stdin is missing
            or is an interactive terminal (nothing was piped in).
    """
    if file is None:
        # Inspect the stream we will actually read, not the interpreter's
        # original stdin (sys.__stdin__), which may differ or be None.
        if sys.stdin is None or sys.stdin.isatty():
            print('no stdin pipe given. exiting..', file=sys.stderr)
            sys.exit(1)
        file = sys.stdin
    return file.read()


def key(graph, index, nodes):
    """Add one node per name in *nodes* to *graph*, namespaced by *index*.

    Node ids take the form ``i<index>_<name>``, so the same name used in
    different states maps to distinct nodes. The visible label is the bare
    name.

    Returns:
        The generated node ids, in the same order as *nodes*.
    """
    names = list(nodes)
    node_ids = ["i{}_{}".format(index, name) for name in names]
    for node_id, name in zip(node_ids, names):
        graph.node(node_id, label=name)
    return node_ids

def _build_cluster(index, statements, label, label_loc):
    """Return the ``cluster_<index>`` subgraph that draws one state.

    Statements are processed in order of their predicate name. Recognised
    predicates (compared lower-cased) are ``n_is_parent``, ``b_points_to``,
    ``is_head``, ``n_is_copy_of`` and ``_color``; any other statement is
    ignored.
    """
    dot = Digraph(name='cluster_{}'.format(index))
    dot.attr(style="dotted", rankdir="BT", splines='false', labelloc=label_loc, label=label)
    # TODO sort all not only first element
    # sorted() rather than list.sort() so the caller's state is not mutated.
    for statement in sorted(statements, key=lambda x: "".join(x[0])):
        pred = statement[0].lower()
        args = statement[1:]
        if pred == 'n_is_parent':
            n = key(dot, index, args)
            # Edge runs from the second argument to the first.
            dot.edge(n[1], n[0], weight="100")
        elif pred == 'b_points_to':
            n = key(dot, index, args)
            # rank='same' subgraph holding the boxed first node and its edge
            # to the second node.
            sub = Digraph()
            sub.attr(rank='same')
            sub.node(n[0], shape='box')
            sub.edge(n[0], n[1], constraint="true", arrowhead='empty')
            dot.subgraph(sub)
        elif pred == 'is_head':
            n = key(dot, index, args)
            dot.node(n[0], fillcolor='aquamarine', style='filled')
        elif pred == 'n_is_copy_of':
            n = key(dot, index, args)
            dot.edge(n[0], n[1], style='dotted', constraint='false', arrowhead='vee', color="gray")
        elif pred == '_color':
            # The second argument is used verbatim as the fill color.
            n = key(dot, index, [args[0]])
            dot.node(n[0], fillcolor=args[1], style='filled')
    return dot


def _render_or_skip(graph, path, files, skip_file):
    """Record *path* and ``path + '.pdf'`` in *files*; render *graph* unless ``skip_file(path)``."""
    files.append(path)
    files.append(path + '.pdf')
    if skip_file(path):
        print("> skipping file ", path)
    else:
        graph.render(path)
        print("> writing file: ", path)


def render_graph(states, actions, name, join, output, noLabel, view, skip_file):
    """Render every parsed state as a graphviz cluster.

    Args:
        states: list of states; each state is a list of statements of the
            form ``[predicate, arg, ...]``.
        actions: labels indexed like *states*, used as cluster labels.
        name: base name of the generated ``<name>-<i>.gv`` files.
        join: combine all clusters into a single ``<name>-joined.gv`` graph
            instead of writing one file per state.
        output: directory the files are written to.
        noLabel: when true, clusters get an empty label and *actions* is
            not consulted.
        view: hand the combined graph to ``Digraph.view`` using a fresh
            temporary ``.gv`` file instead of writing it into *output*.
        skip_file: predicate called with a ``.gv`` path; when it returns
            True that file is not re-rendered.

    Returns:
        The paths this call accounts for (each ``.gv`` file and its
        ``.gv.pdf``). When ``ONLY_RENDER`` is set, nothing is built and only
        the expected per-state ``.gv`` paths are returned.
    """
    if ONLY_RENDER:
        return [os.path.join(output, "{}-{}.gv".format(name, i))
                for i in range(len(states))]

    files = []
    # Top-level graph collecting all clusters when joining or viewing.
    tlg = None
    label_loc = "t"
    if join or view:
        tlg = Digraph()
        tlg.attr(rankdir="BT", splines='false')
        label_loc = 'b'

    for i, s in enumerate(states):
        label = '' if noLabel else actions[i]
        dot = _build_cluster(i, s, label, label_loc)
        if tlg is not None:
            tlg.subgraph(graph=dot)
            print("> subgraph", i)
        else:
            path = os.path.join(output, "{}-{}.gv".format(name, i))
            _render_or_skip(dot, path, files, skip_file)

    if tlg is not None:
        if view:
            # mkstemp creates the file atomically; mktemp only returns a
            # name and is race-prone (deprecated).
            fd, tmp_path = tempfile.mkstemp('.gv')
            os.close(fd)
            tlg.view(tmp_path)
        else:
            path = os.path.join(output, "{}-joined.gv".format(name))
            _render_or_skip(tlg, path, files, skip_file)
    return files

def main():
    """Command-line entry point.

    Reads the input (``--input`` file or piped stdin), parses it with the
    selected ``--type`` parser and renders the states into
    ``<output>/<name>/``. Outside ``ONLY_RENDER`` mode, pre-existing
    ``*.gv*`` files in that directory that this run did not account for are
    deleted, except with ``--view``, which writes nothing there.
    """
    parser = argparse.ArgumentParser()
    parser.add_argument("--type", default="inval", choices=['inval', 'pddl', 'pred'])
    parser.add_argument("--name")
    parser.add_argument('--input', type=open, help="input file, default stdin")
    parser.add_argument("--join", default=False, action='store_true', help="Should the resulting graphs be joined")
    parser.add_argument('--output', '-o', help="output folder", default="output")
    parser.add_argument('--no-labels', action='store_true', default=False, help="Disable labels")
    parser.add_argument('--view', action='store_true', default=False, help="view only")
    parser.add_argument('--skip-newer', action='store_true', help="Skip if pdf  already exists and is newer than the source file")

    args = parser.parse_args()

    if args.name is None:
        if args.input is None:
            print("name or input file required: --name, --input", file=sys.stderr)
            sys.exit(1)
        # Default the name to the input file's base name without extension.
        args.name = os.path.splitext(os.path.basename(args.input.name))[0]

    if args.input is not None and not ONLY_RENDER:
        input_timestamp = os.path.getmtime(args.input.name)
    else:
        input_timestamp = None
    pyprob_timestamp = os.path.getmtime(__file__)

    print("reading input")
    document = get_document(args.input)
    if args.input is not None:
        # argparse opened this file (type=open); release the handle once read.
        args.input.close()
    print("parsing states")
    states, actions = parse_text(args.type, document)

    output_dir = os.path.join(args.output, args.name)
    # Snapshot of pre-existing outputs, used by --skip-newer and the cleanup.
    output_timestamp = {
        path: os.path.getmtime(path)
        for path in glob.glob(os.path.join(output_dir, '*.gv*'))
    }

    def skip(file):
        """Return True when *file* and its PDF exist and are newer than the input and this script."""
        if not args.skip_newer or input_timestamp is None:
            # Without an input file (e.g. stdin) freshness cannot be judged.
            return False
        for candidate in (file, file + '.pdf'):
            stamp = output_timestamp.get(candidate)
            if stamp is None or stamp <= input_timestamp or stamp <= pyprob_timestamp:
                return False
        return True

    print("rendering")
    out_files = render_graph(states, actions, args.name, args.join, output_dir,
                             args.no_labels, args.view, skip)
    print(out_files)
    if ONLY_RENDER:
        for f in out_files:
            with open(f, "r") as fi:
                src = Source(fi.read())
            dest = src.render(f)
            print(f"Rendered file {f} to {dest}")
    elif not args.view:
        # --view returns no output files; cleaning up then would delete
        # every previously rendered file.
        print(f'cleaning output dir {output_dir}')
        keep = set(out_files)
        for f in output_timestamp:
            if f not in keep:
                print("> deleting", f)
                os.remove(f)



if __name__ == "__main__":
    # Run the command-line interface only when executed as a script.
    main()
