from Assignment import Assignment
from IfStatement import IfStatement
from WhileStatement import WhileStatement
from SkipStatement import SkipStatement
from PPCondition import PPCondition
import re

class Parser(object):
    '''
    Line-based parser for a small imperative language with Hoare-style
    pre-/postconditions.

    Each input line holds one statement or keyword line. Parsed statements are
    collected in ``self.statements`` (Assignment, IfStatement, WhileStatement,
    SkipStatement objects) and every variable assigned anywhere in the program,
    including inside IF/WHILE bodies, is recorded once in ``self.variables``.
    '''

    # Comparison operators accepted inside a condition. Two-character operators
    # come before their one-character prefixes so "<=" is not read as "<".
    _COMPARATOR_PATTERN = re.compile(r'<=|>=|<|>|==|!=')

    # "<name>=" not followed by a second "=": otherwise "IF x == 5", compacted
    # to "IFx==5", would be mistaken for an assignment to "IFx".
    _ASSIGNMENT_PATTERN = re.compile(r'\w+=(?!=)')

    def __init__(self, onlyStrings=False):
        # variable names in order of their first assignment
        self.variables = []
        # parsed statement objects in program order
        self.statements = []
        # pre-/postcondition; stay "" unless parseProgram finds a valid one
        self.pre = ""
        self.post = ""
        # forwarded to Assignment and nested parsers; makes parseCondition
        # return raw condition strings instead of PPCondition objects
        self.onlyStrings = onlyStrings

    def parseProgram(self, triple, hasPreAndPost) -> tuple:
        '''
        Parse a list of program lines.

        :param triple: Hoare triple as a list of lines: precondition (optional),
            statements and postcondition (optional)
        :param hasPreAndPost: if this param is False, then the triple only consists
            of statements; otherwise its first and last non-empty lines are the
            pre- and postcondition
        :return: precondition, postcondition, list of variables, list of statements,
            and the index (within the statement lines) where parsing stopped: the
            first top-level END line, or the number of statement lines if there is
            none
        '''
        # drop empty lines from the input triple
        triple = [x for x in triple if x]
        if hasPreAndPost:
            # parse pre- and postcondition
            pre = self.parseCondition(triple[0])
            post = self.parseCondition(triple[-1])
            if pre[0]:
                self.pre = pre[1]
            else:
                print("not a valid precondition")
            if post[0]:
                self.post = post[1]
            else:
                print("not a valid postcondition")

            # extract the statements from the input triple
            program = triple[1:-1]
        else:
            program = triple

        # main parsing loop; IF/WHILE parsers return how many lines they consumed
        i = 0
        while i < len(program):
            s = program[i]
            if self.isAssignment(s):
                self.parseAssignment(s)
                i += 1
            elif self.isIfStatement(s):
                i += self.parseIfStatement(program, i)
            elif self.isWhileStatement(s):
                i += self.parseWhileStatement(program, i)
            elif self.isSkipStatement(s):
                self.parseSkipStatement(s)
                i += 1
            elif s.replace(' ', '') == 'END':
                # stop parsing a sub-program (an IF branch or a WHILE body)
                return self.pre, self.post, self.variables, self.statements, i
            else:
                # unrecognised lines (e.g. whitespace-only ones) are skipped
                i += 1

        return self.pre, self.post, self.variables, self.statements, i

    def isAssignment(self, statement) -> bool:
        '''True if ``statement`` (spaces ignored) starts with "<name>=" but not "<name>=="'''
        statement = statement.replace(' ', '')
        return self._ASSIGNMENT_PATTERN.match(statement) is not None

    def isIfStatement(self, statement) -> bool:
        '''True if ``statement`` (spaces ignored) starts with IF'''
        statement = statement.replace(' ', '')
        return statement.startswith("IF")

    def isWhileStatement(self, statement) -> bool:
        '''True if ``statement`` (spaces ignored) starts with WHILE'''
        statement = statement.replace(' ', '')
        return statement.startswith('WHILE')

    def isSkipStatement(self, statement) -> bool:
        '''True if ``statement`` (spaces ignored) starts with SKIP'''
        statement = statement.replace(' ', '')
        return statement.startswith('SKIP')

    def parseAssignment(self, statement):
        '''
        Parse "<variable> = <expression>" into an Assignment and record the variable.

        Spaces are removed from both sides. Prints a message and adds nothing if the
        statement does not contain exactly one "=".
        '''
        splitStatement = statement.replace(' ', '').split('=')
        if len(splitStatement) != 2:
            print("not a valid assignment")
            return

        # create new assignment and add it to the list of statements
        a = Assignment(splitStatement[0], splitStatement[1], self.onlyStrings)
        self.statements.append(a)

        # check if a new variable is declared and if so add it to the list of variables
        if not self.variableHasBeenDeclared(splitStatement[0]):
            self.addVariable(splitStatement[0])

    def parseCondition(self, condition) -> tuple:
        '''
        Parse a condition written as ``{c1, c2, ...}``.

        :param condition: condition line; spaces are ignored
        :return: (True, parsed) on success, where ``parsed`` is the condition
            string without braces when ``onlyStrings`` is set, and otherwise a list
            with one PPCondition per comma-separated comparison.
            (False, condition) when the line is not wrapped in curly brackets or a
            comparison contains none of <=, >=, <, >, ==, !=
        '''
        # remove whitespaces
        condition = condition.replace(" ", "")
        if not (condition.startswith("{") and condition.endswith("}")):
            return False, condition

        # strip the curly brackets
        condition = condition.replace('{', '').replace('}', '')
        if self.onlyStrings:
            return True, condition

        parsedConditions = []
        for c in condition.split(','):
            match = self._COMPARATOR_PATTERN.search(c)
            if match is None:
                # no comparison operator: cannot build a PPCondition from this part
                return False, condition
            left, comp, right = c[:match.start()], match.group(), c[match.end():]
            parsedConditions.append(PPCondition(left, comp, right))

        return True, parsedConditions

    def parseSkipStatement(self, statement):
        '''append a SkipStatement; the statement text itself is not inspected'''
        self.statements.append(SkipStatement())

    def parseIfStatement(self, program, index) -> int:
        '''
        Parse the IF statement starting at ``program[index]`` and append an
        IfStatement to self.statements.

        Expected layout (one item per line)::

            IF <condition>
            <then statements>
            END
            ELSE
            <else statements>
            END

        The line directly after the first END is skipped without being checked
        (presumably ELSE). Variables assigned in either branch are added to
        self.variables.

        :return: number of lines consumed, i.e. the offset from ``index`` to the
            line after the END that closes the else branch
        '''
        # Extract branching condition (only the leading keyword is removed, so
        # identifiers containing "IF" stay intact)
        branchingCondition = self._stripKeyword(program[index], 'IF')

        # then branch: parsed up to its END
        parser1 = Parser(self.onlyStrings)
        _, _, variablesC1, statementsC1, endIndexC1 = parser1.parseProgram(program[index + 1:], False)
        # skip the IF line, the then branch, its END and the ELSE line
        remainingProgram = program[index + endIndexC1 + 3:]

        # else branch
        parser2 = Parser(self.onlyStrings)
        _, _, variablesC2, statementsC2, endIndexC2 = parser2.parseProgram(remainingProgram, False)

        # variables first assigned inside a branch belong to this program too
        self._mergeVariables(variablesC1)
        self._mergeVariables(variablesC2)

        self.statements.append(IfStatement(branchingCondition, statementsC1, statementsC2))
        # IF + then body + END + ELSE + else body + END
        return endIndexC1 + endIndexC2 + 4

    def parseWhileStatement(self, program, index) -> int:
        '''
        Parse the WHILE statement starting at ``program[index]`` and append a
        WhileStatement to self.statements.

        Expected layout (one item per line)::

            WHILE <condition>
            {<invariant>}
            <body statements>
            END

        Variables assigned in the body are added to self.variables.

        :return: number of lines consumed, i.e. the offset from ``index`` to the
            line after the END that closes the loop body
        '''
        # Extract branching condition (only the leading keyword is removed)
        branchingCondition = self._stripKeyword(program[index], 'WHILE')

        # extract loop invariant: the next line without curly brackets and spaces
        invariant = program[index + 1].replace('{', '').replace('}', '').replace(' ', '')

        parser = Parser(self.onlyStrings)
        _, _, variables, statements, endIndex = parser.parseProgram(program[index + 2:], False)
        self._mergeVariables(variables)

        self.statements.append(WhileStatement(branchingCondition, statements, invariant))
        # WHILE line + invariant line + body + END
        return endIndex + 3

    @staticmethod
    def _stripKeyword(line, keyword) -> str:
        '''remove all spaces from ``line`` and drop ``keyword`` if the line starts with it'''
        compact = line.replace(' ', '')
        if compact.startswith(keyword):
            compact = compact[len(keyword):]
        return compact

    def _mergeVariables(self, variables):
        '''add every name in ``variables`` that has not been declared yet, keeping order'''
        for variable in variables:
            if not self.variableHasBeenDeclared(variable):
                self.addVariable(variable)

    def variableHasBeenDeclared(self, variable) -> bool:
        '''
        check if a given variable has been declared and is in self.variables
        We use this method so we could adapt the data representation of the variables
        '''
        return variable in self.variables

    def addVariable(self, variable):
        '''append ``variable`` to self.variables (no duplicate check)'''
        self.variables.append(variable)