#!/usr/bin/env python
import argparse
import random
import os
import copy


# Output variants, one entry per generated copy of every problem:
#   (folder the problem file is written to,
#    predicate names that are emitted commented out with a leading ';').
# A predicate is commented out when one of the names occurs in its text.
VARIANTS = [
  ("model/default", ['n_is_ancestor']),
  ("model/no-derived", []),
]

# PDDL problem skeleton, filled in with str.format().
# Placeholders: {name} (problem name), {objects} (typed object list),
# {initial_state} (init predicates), {goal_state} (goal predicates).
TEMPLATE = '''
(define (problem {name}) (:domain test_graph)
(:objects 

{objects}

)

(:init

{initial_state}
(= (total-cost) 0)

)

(:goal (
    and

{goal_state}

  )
)

(:metric minimize (total-cost))
)
'''

class State:
  '''
  Randomly evolving git-like commit graph.

  Nodes are named 'n<i>' and branches 'b<i>'. `graph` maps a node to the
  list of its parents (the root has no entry), `branch_pointer` maps a
  branch to the node it points at, and `head` is either a branch name or,
  in detached state, a node name. `copies` maps a node created by a rebase
  to the node it was copied from.
  '''

  def __init__(self, steps):
    '''
    steps: number of successful random actions after which done() is true
    '''
    self.steps = steps
    self.steps_cur = 0

    self.node_count = 0
    self.branch_count = 0

    self.nodes = []
    self.branches = []

    self.graph = {}
    self.head = None
    self.branch_pointer = {}
    self.copies = {}

    # (action, relative weight) pairs used by perform_action
    self.actions = [
      (self.__action_commit, 5),
      (self.__action_branch, 1),
      (self.__action_checkout, 3),
      (self.__action_merge, 2),
      (self.__action_rebase, 1),
    ]


  def done(self):
    '''
    True once at least `steps` actions have succeeded
    '''
    return self.steps <= self.steps_cur

  def __add_parent(self, child, parent):
    '''
    Record parent as an additional parent of child
    '''
    self.graph.setdefault(child, []).append(parent)

  def __get_new_node(self):
    '''
    Create, register and return the next node name
    '''
    n = 'n{}'.format(len(self.nodes))
    self.nodes.append(n)
    return n

  def __get_new_branch(self):
    '''
    Create, register and return the next branch name
    '''
    b = 'b{}'.format(len(self.branches))
    self.branches.append(b)
    return b

  def __head_node(self):
    '''
    Node the head resolves to (directly, or through its branch)
    '''
    if 'n' in self.head:
      return self.head
    else:
      return self.branch_pointer[self.head]

  def __is_ancestor(self, ancestor, child):
    '''
    True if ancestor can be reached from child over one or more parent edges
    '''
    pending = list(self.graph.get(child, []))
    seen = set()
    while pending:
      node = pending.pop()
      if node == ancestor:
        return True
      if node in seen:
        # shared history (merges) only needs to be walked once
        continue
      seen.add(node)
      pending.extend(self.graph.get(node, []))
    return False


  def __move_pointer(self, n):
    '''
    Move head or branch pointer to node
    '''
    p = self.head
    if 'b' in p:
      self.branch_pointer[p] = n
    else:
      self.head = n

  ## Actions that modify the graph
  # Every action returns True when it changed the state and False when it
  # was not applicable; only successful actions count as a step.
  def __action_commit(self):
    '''
    Add a new node on top of the head node and advance the current pointer
    '''
    child = self.__get_new_node()
    self.__add_parent(child, self.__head_node())
    self.__move_pointer(child)
    print("; commit {}".format(child))
    return True

  def __action_branch(self):
    '''
    Create a new branch pointing at the head node
    '''
    branch = self.__get_new_branch()
    self.branch_pointer[branch] = self.__head_node()
    print("; branch {} ({})".format(branch, self.__head_node()))
    return True

  def __action_checkout(self):
    '''
    Point head at a random branch (when the draw exceeds 0.8) or a random node
    '''
    branch = random.uniform(0,1) > 0.8
    if branch:
      self.head = random.choice(self.branches)
    else:
      self.head = random.choice(self.nodes)
    print("; checkout {}".format(self.head))
    return True

  def __action_merge(self):
    '''
    Merge a random branch into the branch that head is on.

    Not applicable in detached head state or when the selected branch
    points at the head node. If the head node is an ancestor of the
    selected branch's node, the current branch is moved to that node;
    otherwise a new commit is created with both nodes as parents.
    '''
    branch = random.choice(self.branches)
    bnode = self.branch_pointer[branch]
    head = self.head

    if 'n' in self.head or self.__head_node() == bnode:
      # selected branch points to same as head or
      # we are in detached head state
      return False

    if self.__is_ancestor(self.__head_node(), bnode):
      # head is always a branch here (detached head returned above)
      self.branch_pointer[self.head] = bnode
    else:
      self.__action_commit()
      self.__add_parent(self.__head_node(), bnode)

    print("; merge {} -> {} head:{}".format(head, branch,self.head ))
    return True

  def __action_rebase(self):
    '''
    Copy a path of the head's history on top of a random branch.

    Searches from the head node towards the root for a path ending in the
    first node that is an ancestor of the other branch's node. Every node
    on that path is copied into a chain of new nodes on top of the other
    branch's node, and each copy is recorded in `copies`. No head or
    branch pointer is moved.
    '''
    head = self.__head_node()
    b_other = random.choice(self.branches)
    n_other = self.branch_pointer[b_other]

    # rebase onto self
    if n_other == head:
      return False

    # already rebased: one of the two nodes is reachable from the other
    if self.__is_ancestor(head, n_other) or self.__is_ancestor(n_other, head):
      return False

    # first common ancestor; every queue entry is a path starting at head
    fca_queue = [[head, parent] for parent in self.graph.get(head, [])]

    fca = None
    while fca_queue:
      path = fca_queue.pop()
      node = path[-1]
      if self.__is_ancestor(node, n_other):
        fca = path
        break

      for parent in self.graph.get(node, []):
        fca_queue.append(path + [parent])

    # no common ancestor found
    if fca is None:
      return False

    prev = n_other
    # only copy single route
    for node in fca:
      new = self.__get_new_node()
      self.graph[new] = [prev]
      self.copies[new] = node
      prev = new

    print("; rebase {head} {branch}({nother})  copies:{list}".format(
      head=head, branch=b_other, nother=n_other, list=fca))
    return True

  def generate_initial_state(self):
    '''
    Populate the empty state with one of four small start graphs and
    put head on the first branch
    '''
    assert len(self.graph) == 0
    root = self.__get_new_node()
    branch = self.__get_new_branch()

    choice = random.choice(range(4))
    if choice == 0:
      # root <- child, branch on child
      child = self.__get_new_node()
      self.__add_parent(child, root)
      self.branch_pointer[branch] = child
    elif choice == 1:
      # single root with a branch on it
      self.branch_pointer[branch] = root
    elif choice == 2:
      # root <- child, one branch on each
      child = self.__get_new_node()
      branch2 = self.__get_new_branch()
      self.branch_pointer[branch] = root
      self.branch_pointer[branch2] = child
      self.__add_parent(child, root)
    elif choice == 3:
      # diamond: center has parents right and left, both children of root
      left = self.__get_new_node()
      right = self.__get_new_node()
      center = self.__get_new_node()
      branch2 = self.__get_new_branch()
      self.branch_pointer[branch] = center
      self.branch_pointer[branch2] = left
      self.__add_parent(left, root)
      self.__add_parent(right, root)
      self.__add_parent(center, right)
      self.__add_parent(center, left)
    self.head = branch

  def perform_action(self):
    '''
    Run one weighted random action; count it as a step only if it succeeded
    '''
    actions = [entry[0] for entry in self.actions]
    weights = [entry[1] for entry in self.actions]

    action = random.choices(actions, weights=weights, k=1)[0]
    if action():
      self.steps_cur += 1


  def format_state(self, is_init, extensions, commented_predicates):
    '''
    Render the state as sorted, de-duplicated PDDL predicates, one per line.

    is_init: also emit n_is_ancestor and is_in_graph predicates
    extensions: callables invoked as ext(state, predicates, is_init) that
      may modify the predicate list before it is formatted
    commented_predicates: names; a predicate whose text contains one of
      them is prefixed with ';'
    '''
    predicates = []

    # parent predicates
    for child, parents in self.graph.items():
      for parent in parents:
        predicates.append(['n_is_parent', parent, child])
      if is_init:
        # breadth-first over the history; the list grows while it is
        # being iterated so it doubles as the work queue
        ancestors = list(parents)
        seen = set()
        for anc in ancestors:
          if anc in seen:
            continue
          seen.add(anc)
          ancestors.extend(self.graph.get(anc, []))
          predicates.append(['n_is_ancestor', anc, child])

    # head predicate
    predicates.append(['is_head', self.head])
    # branch predicates
    for branch, node in self.branch_pointer.items():
      predicates.append(['b_points_to', branch, node])
      if is_init:
        predicates.append(['is_in_graph', branch])

    # node in graph predicates
    if is_init:
      for node in self.nodes:
        predicates.append(['is_in_graph', node])

    for copied, src in self.copies.items():
      predicates.append(['n_is_copy_of', copied, src])

    for ext in extensions:
      ext(self, predicates, is_init)

    # format predicates, drop duplicates and comment out the selected ones
    lines = set()
    for pred in predicates:
      text = "({})".format(" ".join(pred))
      if any(name in text for name in commented_predicates):
        text = f';{text}'
      lines.add(text)

    return "\n".join(sorted(lines))

  def format_objects(self, extensions):
    '''
    Render the PDDL object declarations for all nodes and branches.

    extensions: callables invoked as ext(state, nodes, branches) that may
      add objects. `nodes` is a copy of the node list; `branches` is the
      state's own branch list.
    '''
    n = []
    n.extend(self.nodes)

    b = self.branches

    for ext in extensions:
      ext(self, n, b)

    print("; Total: {} nodes ({} extra), {} branches".format(
      len(n), len(n) - len(self.nodes), len(self.branches)))
    return '{} - node \n{} - branch'.format(" ".join(n), " ".join(b))
  
##### EXTENSIONS ######
# Formatting
def format_extensions(args):
  '''
  Extensions applied while formatting a state; none are defined, so the
  result is always a new empty list regardless of args
  '''
  return []
# Objects
def object_extensions(args):
  '''
  Build the list of object extensions selected on the command line.
  Each extension is called as ext(state, nodes, branches).
  '''

  def extend_objects(state, nodes, branches):
    # append one extra node name per existing node, continuing the numbering
    nodes.extend(
      'n{}'.format(offset + len(state.nodes))
      for offset in range(len(state.nodes))
    )

  return [extend_objects] if args.ext_extend_nodes else []


def generate(args):
  '''
  Generate one random problem and write it once per entry in VARIANTS.

  args must provide `steps` (number of successful actions to apply),
  `name` (PDDL problem name), `file` (output file name) and whatever the
  extension factories read. The init section is taken from a snapshot of
  the start graph, the goal section from the graph after all actions.
  Progress is printed as ';'-prefixed lines.
  '''
  print(f'; GENERATING steps={args.steps}')
  state = State(args.steps)
  state.generate_initial_state()
  # snapshot the start graph before the random actions mutate it
  inits = copy.deepcopy(state)

  while not state.done():
    state.perform_action()

  for folder, commented_predicates in VARIANTS:
    goals = state.format_state(False, format_extensions(args), commented_predicates)
    init = inits.format_state(True, format_extensions(args), commented_predicates)

    objects = state.format_objects(object_extensions(args))
    problem = TEMPLATE.format(objects=objects, name=args.name, initial_state=init, goal_state=goals)

    # a missing variant folder would otherwise discard the generated problem
    os.makedirs(folder, exist_ok=True)
    with open(os.path.join(folder, args.file), "w") as f:
      f.write(problem)
  print('; DONE GENERATING')
  print()


def exists(name):
  '''
  True if a file called name is present in at least one variant folder
  '''
  return any(
    os.path.exists(os.path.join(variant_dir, name))
    for variant_dir, _ in VARIANTS
  )

def nextFilename():
  '''
  Return the first name 'generated<N>' (N counting up from 1, without
  extension) whose '.pddl' file does not exist in any variant folder
  '''
  prefix = "generated"
  index = 1

  while True:
    if not exists(f"{prefix}{index}.pddl"):
      return f"{prefix}{index}"
    index += 1
  


def main():
  '''
  Command line entry point.

  Generates --number problems for every step count in
  [--steps - --inc-div, --steps + --inc-div] and prints a summary of the
  files written. Invalid option combinations end the program through
  parser.error (usage message, exit status 2).
  '''
  parser = argparse.ArgumentParser(description='Generate pddl problems')
  parser.add_argument('--number', default=1, type=int, help="Number of problems to generate")
  parser.add_argument('--steps', '-s', default=6, type=int, help="number of steps to take")
  parser.add_argument('--increment', action="store_true", help="increment from (--steps) - (--inc-div) to (--steps) + (--inc-div) steps where every step number of steps generates (--number) problem files")
  parser.add_argument('--inc-div', default=0, type=int, help="see --increment")
  parser.add_argument('--ext-extend-nodes', action='store_true', help="Double the amount of nodes to allow for copying")
  args = parser.parse_args()

  # validate explicitly: assert statements vanish under `python -O`
  if args.increment and args.inc_div <= 0:
    parser.error("--increment requires --inc-div to be greater than 0")
  if not args.increment and args.inc_div != 0:
    parser.error("--inc-div can only be used together with --increment")

  increment_steps = range(args.steps - args.inc_div, args.steps + args.inc_div + 1)

  allfiles = []

  for steps in increment_steps:
    for _ in range(args.number):
      filename = nextFilename()
      # generate() reads its per-problem settings from the namespace
      args.steps = steps
      args.name = filename
      args.file = f"{filename}.pddl"
      allfiles.append(args.file)
      generate(args)

  print(f"; Generated {len(allfiles)} problems")
  for f in allfiles:
    print(f"; - {f}")

# Run the generator only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
