from nmigen import *
from nmigen.cli import main

from nmigen.back.pysim import Simulator, Delay

from functools import reduce


from typing import List, Dict, Tuple, Optional


from combinatorial_LR_parser import MasterStateMachine


class Cirno(Elaboratable):
    """Top-level wrapper that streams tokens from memory into an LR parser.

    Builds a ``MasterStateMachine`` configured with the rule tables for a
    small arithmetic-expression grammar (``+``, ``*``, parentheses, integers)
    and feeds it 16-bit tokens read from an external memory port.

    Token format (as used by the tables below): the high byte of each 16-bit
    item is a symbol tag (matched with ``STDMASK = 0xff00``). The low byte
    carries a value (read back with ``extractor``).

    Ports:
        input_memory_addr: 8-bit address driven by this module; it is a
            counter that advances on every cycle the parser is ready.
        input_memory_data: 16-bit token read from memory at that address.

    ``tapir`` and ``finalized`` are assigned inside ``elaborate()``. They
    mirror the state machine's ``tapir`` and ``parse_complete_out`` and do
    not exist until the design has been elaborated.
    """

    def __init__(self):
        self.input_memory_addr = Signal(8)
        self.input_memory_data = Signal(16)

    def elaborate(self, platform) -> Module:
        """Build the parser, its rule tables and the memory-feeding logic."""
        # Read address into input memory; advances while the parser is ready.
        cntr = Signal(8)
        # Goes high after the first clock edge and stays high.
        resetted = Signal(1)


        # Symbol tags (high byte). The low byte is free for a payload value.
        BOTTOM      = 0x5a00
        ENDOFPARSE  = 0x5500

        EXPRESSION  = 0xEE00
        TERM        = 0xE700
        FACTOR      = 0xEF00
        INTEGER     = 0xE100

        OPENPAREN   = 0xCA00
        CLOSEPAREN  = 0xCB00

        ADDOP       = 0xAA00
        MULTOP      = 0xAB00
        STDMASK     = 0xff00

        m = Module()
        # BOTTOM = start of parse



        # One entry per parser state: (stack offset, mask, expected tag)
        # patterns. Judging by how the tables are indexed, offset 0 is the
        # top of the stack.
        common_stack_states = [
        # State 0 in the paper
        # Bottom of parse stack
        [(0, STDMASK, BOTTOM)],

        # State 1 in the paper
        # Bottom of parse stack        Expression
        [(1, STDMASK, BOTTOM),          (0, STDMASK, EXPRESSION)],

        # state 2 in paper
        # TERM
        [(0, STDMASK, TERM)],

        # state 3 in paper
        # FACTOR
        [(0, STDMASK, FACTOR)],

        # State 4 in paper
        # OPEN PAREN
        [(0, STDMASK, OPENPAREN)],

        # State 5 in paper
        # INTEGER
        [(0, STDMASK, INTEGER)],

        # State 6 in paper
        # EXPRESSION             PLUS
        [(1, STDMASK, EXPRESSION), (0, STDMASK, ADDOP)],

        # State 7 in paper
        # TERM                     MULTIPLY
        [(1, STDMASK, TERM), (0, STDMASK, MULTOP)],

        # State 8 in paper
        # OPEN PAREN             EXPRESSION
        [(1, STDMASK, OPENPAREN), (0, STDMASK, EXPRESSION)],

        # State 9 in paper
        #  EXPRESSION             PLUS                 TERM
        [(2, STDMASK, EXPRESSION), (1, STDMASK, ADDOP), (0, STDMASK, TERM)],

        # State 10 in paper
        # TERM                    MULTIPLY             FACTOR
        [(2, STDMASK, TERM), (1, STDMASK, MULTOP), (0, STDMASK, FACTOR)],

        # State 11 in paper
        # OPEN PAREN             EXPRESSION             CLOSE PAREN
        [(2, STDMASK, OPENPAREN), (1, STDMASK, EXPRESSION), (0, STDMASK, CLOSEPAREN)]
        ]

        # Pairs of state indices. The tie-break semantics are defined by
        # MasterStateMachine (presumably which state wins when both match);
        # confirm there.
        pairwise_priority_ruleset = [(1,8),(2,9), (3,10)]



        # Per state: (mask, tag) lookahead patterns that are valid inputs.
        validitem_ruleset = [
        # For state 0:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 1:
        [(STDMASK, ADDOP),                                           (STDMASK, ENDOFPARSE)],
        # For state 2:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 3:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 4:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 5:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 6:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 7:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 8:
        [(STDMASK, ADDOP),                    (STDMASK, CLOSEPAREN)],
        # For state 9:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 10:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        # For state 11:
        [(STDMASK, ADDOP), (STDMASK, MULTOP), (STDMASK, CLOSEPAREN), (STDMASK, ENDOFPARSE)],
        ]


        # Per state: lookahead patterns on which the parser shifts.
        forceshift_ruleset = [
        # For state 0:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 1:
        [(STDMASK, ADDOP),                                           (STDMASK, ENDOFPARSE)],
        # For state 2:
        [(STDMASK, MULTOP)],
        # For state 3:
        [],
        # For state 4:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 5:
        [],
        # For state 6:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 7:
        [(STDMASK, INTEGER),                  (STDMASK, OPENPAREN)],
        # For state 8:
        [(STDMASK, ADDOP),                    (STDMASK, CLOSEPAREN)],
        # For state 9:
        [(STDMASK, MULTOP)],
        # For state 10:
        [],
        # For state 11:
        []
        ]


        # Per state: ((mask, tag) lookahead, rule index) pairs. The index
        # appears to select an entry of execute_rules below (e.g. state 9,
        # "E + T", reduces with rule 0).
        reduce_ruleset = [
        # For state 0:
        [],
        # For state 1:
        [],
        # For state 2:
        [((STDMASK, ADDOP),1), ((STDMASK, CLOSEPAREN),1), ((STDMASK, ENDOFPARSE),1)],
        # For state 3:
        [((STDMASK, ADDOP),3), ((STDMASK, MULTOP),3), ((STDMASK, CLOSEPAREN),3), ((STDMASK, ENDOFPARSE),3)],
        # For state 4:
        [],
        # For state 5:
        [((STDMASK, ADDOP),5), ((STDMASK, MULTOP),5), ((STDMASK, CLOSEPAREN),5), ((STDMASK, ENDOFPARSE),5)],
        # For state 6:
        [],
        # For state 7:
        [],
        # For state 8:
        [],
        # For state 9:
        [((STDMASK, ADDOP),0), ((STDMASK, CLOSEPAREN),0), ((STDMASK, ENDOFPARSE),0)],
        # For state 10:
        [((STDMASK, ADDOP),2), ((STDMASK, MULTOP),2), ((STDMASK, CLOSEPAREN),2), ((STDMASK, ENDOFPARSE),2)],
        # For state 11:
        [((STDMASK, ADDOP),4), ((STDMASK, MULTOP),4), ((STDMASK, CLOSEPAREN),4), ((STDMASK, ENDOFPARSE),4)],
        ]

        # Value payload of a stack item (its low byte).
        def extractor(x): return (x & 0x00ff)

        def tagged(tag, value):
            # Combine a symbol tag with a computed value. The value is masked
            # to the low byte so that a sum or product above 0xff cannot carry
            # into the tag byte (EXPRESSION + 0x100 would read as FACTOR).
            # Such results wrap modulo 256.
            return tag + (value & 0x00ff)

        # (n, builder) per production. n is presumably the number of stack
        # items the production consumes; builder maps a view of those items
        # to the new item. Productions, by index:
        #   0: E -> E + T   1: E -> T   2: T -> T * F
        #   3: T -> F       4: F -> ( E )   5: F -> INTEGER
        execute_rules = [
        (3, (lambda stackview: tagged(EXPRESSION, extractor(stackview[0]) + extractor(stackview[2])))),

        (1, (lambda stackview: tagged(EXPRESSION, extractor(stackview[0])))),

        (3, (lambda stackview: tagged(TERM, extractor(stackview[0]) * extractor(stackview[2])))),

        (1, (lambda stackview: tagged(TERM, extractor(stackview[0])))),

        (3, (lambda stackview: tagged(FACTOR, extractor(stackview[1])))),

        (1, (lambda stackview: tagged(FACTOR, extractor(stackview[0]))))
        ]

        msm = MasterStateMachine(item_width=16, indices_width=16, stack_depth=16,
            validitem_ruleset = validitem_ruleset,
            pairwise_priority_ruleset = pairwise_priority_ruleset,
            forceshift_ruleset = forceshift_ruleset,
            reduce_ruleset=reduce_ruleset,
            execute_rules=execute_rules, stack_state_descriptions=common_stack_states, endofparse_marker=(0xffff, ENDOFPARSE))

        m.submodules.StateMachine = msm

        # Exposed for the testbench; only available after elaboration.
        self.tapir = msm.tapir
        self.finalized = msm.parse_complete_out

        # From the second cycle on, tell the parser its input is valid.
        with m.If(resetted == 0):
            m.d.sync += resetted.eq(1)
        with m.If(resetted == 1):
            m.d.comb += msm.data_in_valid.eq(1)

        # msm.data_in_ready delayed by one clock cycle.
        stall_recovery = Signal(1)

        # Memory word captured on the first cycle of a stall.
        cashew_register = Signal(16)
        m.d.sync += stall_recovery.eq(msm.data_in_ready) # one-cycle-delayed

        # Parser ready: advance the address and forward the memory word.
        with m.If(msm.data_in_ready == 1):
            m.d.sync += cntr.eq(cntr + 1)
            m.d.comb += msm.data_in.eq(self.input_memory_data)

        # First stalled cycle: keep forwarding the current word and latch it.
        # NOTE(review): this assumes memory data lags the address by one
        # clock, so the word on the bus is replaced on the next cycle.
        with m.If((msm.data_in_ready == 0) & (stall_recovery == 1)):
            m.d.sync += cashew_register.eq(self.input_memory_data)
            m.d.comb += msm.data_in.eq(self.input_memory_data)

        # Ready was low last cycle: replay the latched word. This block comes
        # last, so under nmigen's last-assignment-wins rule it overrides the
        # assignments above. That includes the cycle on which ready returns high.
        with m.If(stall_recovery == 0):
            m.d.comb += msm.data_in.eq(cashew_register)


        m.d.comb += self.input_memory_addr.eq(cntr)



        return m



# Symbol tags for the expression grammar. A 16-bit stack item keeps its
# symbol in the high byte, and the low byte is free for a payload value.
# NOTE: Cirno.elaborate defines its own copies of these values; keep the two
# sets in sync.

# Parse-stack sentinels.
BOTTOM, ENDOFPARSE = 0x5A00, 0x5500

# Nonterminals, plus the integer terminal.
EXPRESSION, TERM, FACTOR, INTEGER = 0xEE00, 0xE700, 0xEF00, 0xE100

# Grouping symbols.
OPENPAREN, CLOSEPAREN = 0xCA00, 0xCB00

# Operators.
ADDOP, MULTOP = 0xAA00, 0xAB00

# Mask that selects the symbol-tag byte of an item.
STDMASK = 0xFF00


if __name__ == '__main__':
    top = Module()
    dut = Cirno()
    top.submodules.baka = dut

    # One snapshot of the 64 dut.tapir entries per sampled cycle.
    trace = []

    def process():
        """Sample the parser every cycle until ``finalized`` becomes nonzero."""
        while True:
            done = yield dut.finalized
            print(done)
            if done != 0:
                return
            yield  # advance one clock cycle
            done = yield dut.finalized
            print(done)

            snapshot = []
            for idx in range(64):
                snapshot.append((yield dut.tapir[idx]))
            trace.append(snapshot)

    # Token stream for "((2) + 3) + 1": INTEGER tags carry the value in the
    # low byte.
    init_data = [OPENPAREN, OPENPAREN, 0xE102, CLOSEPAREN, ADDOP, 0xE103, CLOSEPAREN, ADDOP, 0xE101, ENDOFPARSE]

    # Synchronous ROM: the word for an address appears one clock later;
    # unmapped addresses read 0xf00d.
    with top.Switch(dut.input_memory_addr):
        for address, word in enumerate(init_data):
            with top.Case(address):
                top.d.sync += dut.input_memory_data.eq(word)
        with top.Default():
            top.d.sync += dut.input_memory_data.eq(0xf00d)

    simulator = Simulator(top)
    simulator.add_clock(1e-9)
    simulator.add_sync_process(process)
    with simulator.write_vcd("test.vcd", "test.gtkw"):
        simulator.run()

    for row in trace:
        print(row)



class TreeNode:
    """A parse-tree node: a language element together with its child nodes.

    ``str(node)`` is ``str(node.language_element)``. ``repr(node)`` shows the
    element and its children, so lists of nodes print readably.
    """

    def __init__(self, language_element, subnodes: Optional[List] = None):
        """Create a node.

        Args:
            language_element: The symbol or value stored at this node.
            subnodes: Child nodes, stored as given (not copied). Omitted or
                None means no children.
        """
        self.language_element = language_element
        # Use a fresh list per node so children are never shared by default.
        self.subnodes = [] if subnodes is None else subnodes

    def __str__(self):
        return str(self.language_element)

    def __repr__(self):
        return f"{type(self).__name__}({self.language_element!r}, {self.subnodes!r})"

# Scratch check of TreeNode: print a childless node, then attach one child.
x = TreeNode("abcd", [])
print(x, x.subnodes, sep="\n")

x.subnodes = [TreeNode("defg", [423])]
print(x, x.subnodes, x.subnodes[0].subnodes, sep="\n")