from nmigen import *
from nmigen.cli import main

from nmigen.back.pysim import Simulator, Delay

from functools import reduce


from typing import List, Dict, Tuple, Optional


from combinatorial_LR_parser import MasterStateMachine


import random


import CFGBoltzmann

class TreeNode:
    """A node of a parse or derivation tree.

    ``language_element`` holds the node's token value (``__str__`` turns it
    into a name with ``hex_to_name``); ``subnodes`` holds the child nodes in
    order.
    """

    def __init__(self, language_element, subnodes):
        self.subnodes = subnodes
        self.language_element = language_element

    def __str__(self):
        # Symbolic name of this node immediately followed by its child list.
        return "{}{}".format(hex_to_name(self.language_element), self.subnodes)


def walk_the_tree(tree, level=0):
    """Print ``tree`` as an indented outline, one line per node.

    Each node prints as ``|---NAME with N subnodes`` (NAME comes from
    ``hex_to_name``), indented four spaces per ``level``. Its subnodes follow
    recursively at ``level + 1``. A ``None`` tree prints nothing.
    """
    if tree is None:
        return

    line = "{}|---{} with {} subnodes".format(
        "    " * level, hex_to_name(tree.language_element), len(tree.subnodes)
    )
    print(line)
    for subnode in tree.subnodes:
        walk_the_tree(subnode, level + 1)


class Cirno(Elaboratable):
    """Top-level test wrapper around the LR parser's MasterStateMachine.

    Configures MasterStateMachine (from combinatorial_LR_parser) with parse
    tables for the arithmetic grammar
        EXPRESSION -> EXPRESSION ADDOP TERM | TERM
        TERM       -> TERM MULTOP FACTOR    | FACTOR
        FACTOR     -> OPENPAREN EXPRESSION CLOSEPAREN | INTEGER
    and streams 16-bit token words into it from an external memory. This
    module drives ``input_memory_addr`` and reads the word for that address
    from ``input_memory_data``.

    ``elaborate()`` also publishes the parser's outputs as the attributes
    ``tapir``, ``finalized``, ``numwritten`` and ``parse_success``. They do
    not exist before elaboration.
    """
    def __init__(self):
        # Address of the token word requested from the external memory
        # (driven combinationally from the internal counter).
        self.input_memory_addr = Signal(8)
        # Token word returned by the external memory.
        self.input_memory_data = Signal(16)

    def elaborate(self, platform):
        """Build the parser instance and its token-feeding logic.

        Side effect: sets ``self.tapir``, ``self.finalized``,
        ``self.numwritten`` and ``self.parse_success`` from the state
        machine's ports.
        """
        # Read pointer into the token memory; advances on every cycle in which
        # the parser asserts data_in_ready.
        cntr = Signal(8)
        # 0 on the first cycle after reset, 1 from then on.
        resetted = Signal(1)

        # NOTE(review): the next three signals are declared but never used below.
        valid_data_out = Signal(1)
        romem_ready = Signal(1)
        data_port = Signal(16)


        # Token encodings (same values as the module-level ``tokens`` table).
        # Every token has a zero low byte; STDMASK selects the high byte.
        BOTTOM      = 0x5a00
        ENDOFPARSE  = 0x5500

        EXPRESSION  = 0xEE00
        TERM        = 0xE700
        FACTOR      = 0xEF00
        INTEGER     = 0xE100

        OPENPAREN   = 0xCA00
        CLOSEPAREN  = 0xCB00

        ADDOP       = 0xAA00
        MULTOP      = 0xAB00
        STDMASK     = 0xff00

        m = Module()
        # BOTTOM = start of parse



        # Stack patterns that identify each parser state, numbered after the
        # referenced paper. Each tuple is presumably (depth below top of
        # stack, mask, expected value), with depth 0 = top of stack.
        # TODO confirm against MasterStateMachine.
        common_stack_states = [
        # State 0 in the paper
        # Bottom of parse stack
        [(0, STDMASK, BOTTOM)],

        # State 1 in the paper
        # Bottom of parse stack        Expression
        [(0, STDMASK, EXPRESSION)],

        # state 2 in paper
        # TERM
        [(0, STDMASK, TERM)],

        # state 3 in paper
        # FACTOR
        [(0, STDMASK, FACTOR)],

        # State 4 in paper
        # OPEN PAREN
        [(0, STDMASK, OPENPAREN)],

        # State 5 in paper
        # INTEGER
        [(0, STDMASK, INTEGER)],

        # State 6 in paper
        # EXPRESSION             PLUS
        [(1, STDMASK, EXPRESSION), (0, STDMASK, ADDOP)],

        # State 7 in paper
        # TERM                     MULTIPLY
        [(1, STDMASK, TERM), (0, STDMASK, MULTOP)],

        # State 8 in paper
        # OPEN PAREN             EXPRESSION
        [(1, STDMASK, OPENPAREN), (0, STDMASK, EXPRESSION)],

        # State 9 in paper
        #  EXPRESSION             PLUS                 TERM
        [(2, STDMASK, EXPRESSION), (1, STDMASK, ADDOP), (0, STDMASK, TERM)],

        # State 10 in paper
        # TERM                    MULTIPLY             FACTOR
        [(2, STDMASK, TERM), (1, STDMASK, MULTOP), (0, STDMASK, FACTOR)],

        # State 11 in paper
        # OPEN PAREN             EXPRESSION             CLOSE PAREN
        [(2, STDMASK, OPENPAREN), (1, STDMASK, EXPRESSION), (0, STDMASK, CLOSEPAREN)]
        ]

        # Pairs of overlapping state patterns: each second state's pattern also
        # matches whenever the first one does (e.g. state 8 "( E" vs state 1
        # "E"). Presumably the second (longer) pattern takes priority; semantics
        # are defined by MasterStateMachine.
        pairwise_priority_ruleset = [(1,8),(2,9), (3,10)]



        # Per state: the (mask, lookahead) tokens accepted in that state
        # (presumably any other lookahead is a parse error). In every state this
        # list is exactly the union of the forceshift and reduce lookaheads below.
        validitem_ruleset = [
        # For state 0:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 1:
        [(STDMASK, ADDOP),                                           (STDMASK, ENDOFPARSE)],
        # For state 2:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 3:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 4:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 5:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 6:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 7:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 8:
        [(STDMASK, ADDOP),                    (STDMASK, CLOSEPAREN)],
        # For state 9:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 10:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 11:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        ]


        # Per state: the (mask, lookahead) tokens on which the parser shifts
        # (e.g. in state 2, TERM, a MULTOP is shifted rather than reduced).
        forceshift_ruleset = [
        # For state 0:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 1:
        [(STDMASK, ADDOP),                                           (STDMASK, ENDOFPARSE)],
        # For state 2:
        [(STDMASK, MULTOP)],
        # For state 3:
        [],
        # For state 4:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 5:
        [],
        # For state 6:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 7:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 8:
        [(STDMASK, ADDOP),                    (STDMASK, CLOSEPAREN)],
        # For state 9:
        [(STDMASK, MULTOP)],
        # For state 10:
        [],
        # For state 11:
        []
        ]


        # Per state: ((mask, lookahead), rule) pairs. The rule number appears to
        # index execute_rules below, e.g. state 9 "E + T" reduces by rule 0
        # (E -> E + T) and state 5 "INTEGER" by rule 5 (F -> INTEGER).
        reduce_ruleset = [
        # For state 0:
        [],
        # For state 1:
        [],
        # For state 2:
        [((STDMASK, ADDOP),1), ((STDMASK, CLOSEPAREN),1), ((STDMASK, ENDOFPARSE),1)],
        # For state 3:
        [((STDMASK, ADDOP),3), ((STDMASK, MULTOP),3), ((STDMASK, CLOSEPAREN),3), ((STDMASK, ENDOFPARSE),3)],
        # For state 4:
        [],
        # For state 5:
        [((STDMASK, ADDOP),5), ((STDMASK, MULTOP),5), ((STDMASK, CLOSEPAREN),5), ((STDMASK, ENDOFPARSE),5)],
        # For state 6:
        [],
        # For state 7:
        [],
        # For state 8:
        [],
        # For state 9:
        [((STDMASK, ADDOP),0), ((STDMASK, CLOSEPAREN),0), ((STDMASK, ENDOFPARSE),0)],
        # For state 10:
        [((STDMASK, ADDOP),2), ((STDMASK, MULTOP),2), ((STDMASK, CLOSEPAREN),2), ((STDMASK, ENDOFPARSE),2)],
        # For state 11:
        [((STDMASK, ADDOP),4), ((STDMASK, MULTOP),4), ((STDMASK, CLOSEPAREN),4), ((STDMASK, ENDOFPARSE),4)],
        ]

        # Low byte of a stack item; only used by the commented-out semantic
        # actions below.
        def extractor(x): return (x & 0x00ff)
        # One entry per production, in the same order as the module-level
        # ``rules`` list. The first element (3 or 1) matches each production's
        # right-hand-side length. The lambda returns the symbol that replaces it.
        # Only the bare nonterminal is produced; the commented-out tails would
        # fold values from the low bytes into the result.
        execute_rules = [
        (3, (lambda stackview: EXPRESSION )), #+ (extractor(stackview[0]) + extractor(stackview[2])))),

        (1, (lambda stackview: EXPRESSION)), # + extractor(stackview[0]))),

        (3, (lambda stackview: TERM )), #+ (extractor(stackview[0]) * extractor(stackview[2])))),

        (1, (lambda stackview: TERM )),# + extractor(stackview[0]))),

        (3, (lambda stackview: FACTOR )), #+ extractor(stackview[1]))),

        (1, (lambda stackview: FACTOR ))#+ extractor(stackview[0])))
        ]



        # startofparse_marker equals BOTTOM. serialized_tree_length=128 is the
        # number of tapir entries run_the_sim snapshots. endofparse_marker is
        # presumably (mask, value).
        msm = MasterStateMachine(item_width=16, indices_width=16, stack_depth=16,
            validitem_ruleset = validitem_ruleset,
            pairwise_priority_ruleset = pairwise_priority_ruleset,
            forceshift_ruleset = forceshift_ruleset,
            reduce_ruleset=reduce_ruleset,
            execute_rules=execute_rules, stack_state_descriptions=common_stack_states,
            startofparse_marker = 0x5a00, serialized_tree_length=128, endofparse_marker=(0xffff, ENDOFPARSE))

        m.submodules.StateMachine = msm

        # Expose the parser's outputs to the testbench.
        self.tapir = msm.tapir
        self.finalized = msm.parse_complete_out
        self.numwritten = msm.last_index_to_smem
        self.parse_success = msm.parse_success_out

        # Start-up: resetted rises one clock after reset. From then on
        # data_in_valid is held high (it is 0 on the first cycle).
        with m.If(resetted == 0):
            m.d.sync += resetted.eq(1)
        with m.If(resetted == 1):
            m.d.comb += msm.data_in_valid.eq(1)

        # Input feed with stall handling. stall_recovery is data_in_ready
        # delayed by one clock. In nmigen the last statement that assigns a
        # signal wins, so the order of the three m.If blocks below matters:
        #   * data_in_ready=1: cntr advances, data_in = input_memory_data
        #     (unless overridden by the third block).
        #   * data_in_ready=0 and stall_recovery=1 (first stalled cycle):
        #     cashew_register latches input_memory_data, which is also presented.
        #   * stall_recovery=0 (later stalled cycles, and the first cycle after
        #     data_in_ready returns): data_in = cashew_register, overriding the
        #     first block.
        # NOTE(review): this appears to assume the memory returns data one
        # clock after input_memory_addr changes (as run_the_sim models it).
        # Confirm before changing the memory latency.
        stall_recovery = Signal(1)

        # Holds the word that was on input_memory_data when a stall began.
        cashew_register = Signal(16)
        m.d.sync += stall_recovery.eq(msm.data_in_ready) # one-cycle-delayed

        with m.If(msm.data_in_ready == 1):
            m.d.sync += cntr.eq(cntr + 1)
            m.d.comb += msm.data_in.eq(self.input_memory_data)

        with m.If((msm.data_in_ready == 0) & (stall_recovery == 1)):
            m.d.sync += cashew_register.eq(self.input_memory_data)
            m.d.comb += msm.data_in.eq(self.input_memory_data)

        # Must stay after the blocks above so it takes priority over them.
        with m.If(stall_recovery == 0):
            m.d.comb += msm.data_in.eq(cashew_register)


        # The memory address is simply the read pointer.
        m.d.comb += self.input_memory_addr.eq(cntr)



        return m


def run_the_sim(parse_me):
    """Simulate Cirno on one token string and collect the parser's output.

    ``parse_me`` is loaded into a lookup table indexed by
    ``input_memory_addr``. The selected word is registered onto
    ``input_memory_data``, so it arrives one clock later. Addresses past the
    end of ``parse_me`` read 0xf00d. A VCD trace is written to ``test.vcd``,
    with the GTKWave save file ``test.gtkw``.

    A sync process snapshots the 128-entry ``tapir`` buffer once per clock
    while ``finalized`` is low. Once ``finalized`` is seen high, it waits two
    more clocks, samples ``numwritten`` and ``parse_success``, and stops.

    Args:
        parse_me: sequence of 16-bit token words, in input order.

    Returns:
        ``(trace, numwritten, success)``:
        - ``trace`` is the list of ``tapir`` snapshots, oldest first.
        - ``numwritten`` is a one-element list holding the sampled
          ``numwritten`` value.
        - ``success`` is True iff ``parse_success`` read as 1.

    NOTE(review): if ``finalized`` never rises, the simulation never ends.
    """
    m = Module()
    m.submodules.baka = nine = Cirno()
    trace = []
    numwritten = []
    parse_success = [0]
    # Must match serialized_tree_length given to MasterStateMachine in Cirno.
    tree_length = 128

    def process():
        while True:
            finalized = yield nine.finalized
            if finalized == 0:
                yield  # advance one clock, then snapshot the tree buffer
                # (a plain loop: `yield` is not allowed inside a comprehension)
                array = []
                for idx in range(tree_length):
                    x = yield nine.tapir[idx]
                    array.append(x)
                trace.append(array)
            else:
                # Wait two more clocks before sampling the final counters.
                yield
                yield
                xz = yield nine.numwritten
                parse_success[0] = yield nine.parse_success
                numwritten.append(xz)
                print("NUM WRITTEN INSIDE", numwritten)
                print("PARSE SUCCESS,", parse_success)
                break

    # Token memory: registered lookup of parse_me by address.
    with m.Switch(nine.input_memory_addr):
        for addr, data in enumerate(parse_me):
            with m.Case(addr):
                m.d.sync += nine.input_memory_data.eq(data)
        with m.Default():
            m.d.sync += nine.input_memory_data.eq(0xf00d)

    sim = Simulator(m)
    sim.add_clock(1e-9)
    sim.add_sync_process(process)
    with sim.write_vcd("test.vcd", "test.gtkw"):
        sim.run()

    for x in trace:
        print(x)
    print("XXXXXXXXXXXXXXXXXXXXX", numwritten, parse_success)

    success = parse_success[0] == 1
    return (trace, numwritten, success)

# Flag bit in serialized-tree entries: deserializer treats an entry with this
# bit set as a back-reference (entry - top_bit is a physical buffer index).
top_bit = 0x10000


def deserializer(serialized_array, last_idx_written):
    """Rebuild a TreeNode tree from the parser's serialized tree buffer.

    The buffer is read as consecutive records starting at index 0::

        [element, n, sub_0, ..., sub_{n-1}]

    A ``sub_i`` with ``top_bit`` set is a back-reference: ``sub_i - top_bit``
    is the buffer index of an earlier record, whose node is reused. Any other
    ``sub_i`` becomes a new leaf ``TreeNode(sub_i, [])``. Records are decoded
    while their start index is <= ``last_idx_written``. Decoding stops early
    at a record with ``n < 1``, or one that runs past the end of the buffer.

    Fault injection: each decoded record's element is, with probability about
    1 / (32 * last_idx_written + 1), replaced by a different random grammar
    symbol, and ``random_triggered`` is set. With a further 1-in-100001
    chance the flag is cleared again. Presumably this checks that the
    caller's tree comparison still notices the corruption.

    Args:
        serialized_array: the tree buffer as a sequence of ints.
        last_idx_written: index of the last buffer entry that belongs to the
            tree.

    Returns:
        ``(root, random_triggered)``. ``root`` is the node built from the last
        decoded record, or ``None`` if no record was decoded.
    """
    def _report_bad_record(kind, idx):
        # kind is "MALFORMED" or "TRUNCATED"; prints e.g. "MALFORMED AT IDX 4".
        print("")
        print(kind, "AT IDX", idx)
        print("TOTAL LEN OF ARRAY IS", len(serialized_array))

    random_triggered = False
    list_of_nodes = []
    physical_to_logical = {}
    physical_idx = 0
    logical_idx  = 0
    # Stays None when no record is decoded (empty range, or a malformed or
    # truncated first record). Previously this raised UnboundLocalError at
    # the return.
    new_node = None
    array_len = len(serialized_array)
    while (physical_idx < last_idx_written + 1):
        if physical_idx + 1 >= array_len:
            # Not even room for the [element, n] header.
            _report_bad_record("TRUNCATED", physical_idx)
            break

        new_element     = serialized_array[physical_idx]
        print("The element at the current index", physical_idx, "is:", hex(new_element))
        number_subnodes = serialized_array[physical_idx + 1]
        print("    WITH", number_subnodes, "SUBNODES")
        if (number_subnodes < 1):
            _report_bad_record("MALFORMED", physical_idx)
            break

        record_end = physical_idx + 2 + number_subnodes
        if record_end > array_len:
            # The subnode list would run past the end of the buffer.
            _report_bad_record("TRUNCATED", physical_idx)
            break

        subnodes_array = []
        for sub_idx in range(physical_idx + 2, record_end):
            entry = serialized_array[sub_idx]
            if (entry & top_bit) != 0:  # back-reference to an earlier record
                backreference_physical_index = entry - top_bit
                backreference_logical_index = physical_to_logical[backreference_physical_index]
                print("        BACKREFERENCE PHYS INDEX IS", backreference_physical_index)
                print("        BACKREFERENCE LOGICAL INDEX IS", backreference_logical_index)
                subnodes_array.append(list_of_nodes[backreference_logical_index])
            else:                       # new leaf subnode
                print("        ABINITIO SUBNODE IS", entry)
                subnodes_array.append(TreeNode(entry, []))

        print("        AND THESE SUBNODES ARE:")
        for idx, x in enumerate(subnodes_array):
            print("        SUBNODE NUMBER", idx)
            walk_the_tree(x, level=3)
            print("")
        print("")

        # Fault injection (see docstring).
        if (random.randint(0, last_idx_written * 32) == 0):
            false_element = random.choices(list_of_nonterminals + list_of_terminals)[0]

            if (false_element != new_element):
                print("RANDOM TRIGGERED! EXPECT FALSE RESULT!")
                random_triggered = True
                if(random.randint(0, 100000) == 0):
                    print("RANDOM TRIGGERED BUT NOT REPORTED! AHAHAHA XXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXXX")
                    random_triggered = False
                new_element = false_element

        new_node = TreeNode(new_element, subnodes_array)

        list_of_nodes.append(new_node)

        physical_to_logical[physical_idx] = logical_idx
        physical_idx = record_end
        logical_idx  += 1
    print("physical to logical mapping is", physical_to_logical)
    return (new_node, random_triggered)

def are_trees_equal(tree_one_root, tree_two_root):
    """Return True if two trees have identical structure and elements.

    Two trees are equal when their roots carry the same ``language_element``,
    they have the same number of subnodes, and corresponding subnodes are
    pairwise equal (recursively). ``None`` is treated as a missing tree: two
    ``None`` roots are equal, and ``None`` against a node is unequal. This is
    what the deserializer returns when it decodes no record.
    """
    if tree_one_root is None or tree_two_root is None:
        return tree_one_root is tree_two_root
    if tree_one_root.language_element != tree_two_root.language_element:
        return False
    if len(tree_one_root.subnodes) != len(tree_two_root.subnodes):
        return False
    return all(
        are_trees_equal(left, right)
        for left, right in zip(tree_one_root.subnodes, tree_two_root.subnodes)
    )


# Symbolic names for the 16-bit token words used throughout this file.
# Every token word has a zero low byte; STDMASK (0xff00) selects the high byte.
tokens = {
    "BOTTOM":     0x5A00,
    "ENDOFPARSE": 0x5500,
    "EXPRESSION": 0xEE00,
    "TERM":       0xE700,
    "FACTOR":     0xEF00,
    "INTEGER":    0xE100,
    "OPENPAREN":  0xCA00,
    "CLOSEPAREN": 0xCB00,
    "ADDOP":      0xAA00,
    "MULTOP":     0xAB00,
    "STDMASK":    0xFF00,
}

# Inverse table (token word -> name); all values above are distinct.
revmap = dict(zip(tokens.values(), tokens.keys()))

print(revmap)

def hex_to_name(x):
    """Return the symbolic name (a key of ``tokens``) for token word ``x``.

    Raises KeyError when ``x`` is not one of the values in ``tokens``.
    """
    name = revmap[x]
    return name


# Module-level token words (same values as the ``tokens`` table), used by the
# reference grammar and the fuzz driver below.
BOTTOM = 0x5A00
ENDOFPARSE = 0x5500

EXPRESSION = 0xEE00
TERM = 0xE700
FACTOR = 0xEF00
INTEGER = 0xE100

OPENPAREN = 0xCA00
CLOSEPAREN = 0xCB00

ADDOP = 0xAA00
MULTOP = 0xAB00
STDMASK = 0xFF00

# Reference grammar as (lhs, rhs-list) productions, listed in the same order as
# Cirno's execute_rules:
#   0: EXPRESSION -> EXPRESSION ADDOP TERM     1: EXPRESSION -> TERM
#   2: TERM       -> TERM MULTOP FACTOR        3: TERM       -> FACTOR
#   4: FACTOR     -> OPENPAREN EXPRESSION CLOSEPAREN
#   5: FACTOR     -> INTEGER
rules = [
    (EXPRESSION, [EXPRESSION, ADDOP, TERM]),
    (EXPRESSION, [TERM]),
    (TERM, [TERM, MULTOP, FACTOR]),
    (TERM, [FACTOR]),
    (FACTOR, [OPENPAREN, EXPRESSION, CLOSEPAREN]),
    (FACTOR, [INTEGER]),
]

list_of_nonterminals = [EXPRESSION, TERM, FACTOR]
list_of_terminals = [INTEGER, ADDOP, MULTOP, OPENPAREN, CLOSEPAREN]



# Random-string generator for the grammar above (project-local CFGBoltzmann;
# presumably a Boltzmann-style sampler). do_an_iteration() uses this ``z`` via
# Gzero_shimmed() to draw test strings and reversiflip() on derivation trees.
z = CFGBoltzmann.CFGBoltzmann(rules, list_of_nonterminals, list_of_terminals)
# NOTE(review): cooked_rules is not read anywhere in the visible code. The call
# may still prepare state on ``z`` that later calls rely on; confirm before
# removing.
cooked_rules = z.preprocessor()

def do_an_iteration():
    """Run one randomized round of the parser fuzz test.

    Draws a random EXPRESSION string and its derivation tree from ``z``. The
    second argument to Gzero_shimmed is drawn from 1..21; presumably it is a
    target size. The round then appends ENDOFPARSE, simulates the parse with
    run_the_sim, and deserializes the final tree buffer. Finally it compares
    that tree with the derivation tree after ``z.reversiflip``.

    The deserializer may corrupt a node on purpose and reports this via
    ``random_trig``. The round is consistent when exactly one of "corruption
    reported" and "trees equal" holds. Otherwise this prints an error and
    terminates the process with exit status 1. Rounds whose parse does not
    report success are skipped, with a message.
    """
    import sys  # sys.exit, unlike the site-provided exit(), is always available

    bgen = z.Gzero_shimmed(EXPRESSION, random.randint(1,21))
    parse_me = bgen[0]          # generated token string
    derivation_tree = bgen[1]   # generator's derivation tree for it

    parse_me.append(ENDOFPARSE)

    (trace, numwritten, success) = run_the_sim(parse_me)

    if not success:
        # NOTE(review): the string comes from the grammar, so a failed parse
        # may itself indicate a parser bug; the round is skipped, but visibly.
        print("PARSE DID NOT SUCCEED, SKIPPING TREE COMPARISON")
        return

    numwritten = numwritten[0] - 1
    print("NUM WRITTEN = ",numwritten)

    serialized_tree_final = trace[-1]
    print()
    print("DESER TREE FINAL is [{}]".format(', '.join(hex(x) for x in serialized_tree_final)))

    (parser_output_tree, random_trig) = deserializer(serialized_tree_final, numwritten)

    walk_the_tree(parser_output_tree,0)

    print("reversed derivation tree is")
    z.reversiflip(derivation_tree)
    walk_the_tree(derivation_tree)

    print("THE ORIGINAL STRING WAS", [hex_to_name(x) for x in parse_me])
    print("WAS RANDOM TRIGGERED??", random_trig)
    trees_equal = are_trees_equal(derivation_tree, parser_output_tree)
    print("ARE TREES EQUAL???", trees_equal)
    # Consistent outcomes: (corruption reported, trees unequal) or
    # (no corruption reported, trees equal).
    if random_trig == trees_equal:
        print("MASSIVE ERROR! MASSIVE ERROR!")
        sys.exit(1)


# Fuzz campaign: 2**17 (= 131072) independent rounds. do_an_iteration() exits
# the process with status 1 on the first inconsistent result.
for x in range(2 ** 17):
    do_an_iteration()