Skip to content
Snippets Groups Projects
simple_parse_tree.py 11.03 KiB
from nmigen import *
from nmigen.hdl.rec import *
from nmigen.cli import main


# An example parse tree looks like:

#                            nonterminal A                          
#                          **      |   **                           
#                       **         |      **                        
#                    **            |         **                     
#                  *               |           *                    
#            nonterminal B    terminal K    nonterminal D           
#           /          *                                   *        
#          /            *                                   *       
#    terminal I       nonterminal C                      terminal L 
#                        \                                          
#                         \                                         
#                       terminal J                                  


# As it is parsed, we serialize it on the fly. The serialization scheme represents the tree
# as a sequence of records, each of which corresponds to a reduction. In other words, each
# nonterminal appearing in the parse tree will have a corresponding record.
#
# The list of records, in order, is therefore as follows:
#
# nonterminal C -> terminal J
# nonterminal B -> terminal I, nonterminal C
# nonterminal D -> terminal L
# nonterminal A -> nonterminal B, terminal K, nonterminal D

# We note that for every nonterminal on the right side of a record, that nonterminal has already
# appeared on the left side of a previous record. This is because of how a bottom-up parse proceeds.

# Each serialized parse tree record looks like this:
#
#        FIELD 0               FIELD 1                FIELD 2                 FIELD 3                      FIELD N                 FIELD N+1
# [Nonterminal symbol] [N = number of subnodes] [tag bit 1, subnode 1] [tag bit 2, subnode 2] ... [tag bit N-1, subnode N-1] [tag bit N, subnode N]

# Field 0 indicates the symbol number of the nonterminal created by the reduction. Its bitwidth is
# ceil(log_2(number of nonterminals)).

# The number of subnodes, N, is represented as an unsigned integer. To determine its bitwidth, we
# examine all the language rules, find the rule with the most symbols (terminal and nonterminal
# alike) on its right-hand side, and take ceil(log_2(that count + 1)). The "+ 1" is needed because
# N must be able to hold the maximum count itself, not just the values below it.

# bitwidth(N) = ceil(log_2(1 + max(len(RHS of rule 1), len(RHS of rule 2), len(RHS of rule 3), ...)))

# There are N slots for subnodes following N, and each slot contains a tag bit and a subnode identifier:

# If the subnode is a terminal, the symbol/token number of that terminal is placed in the corresponding
# slot in the record, and the tag bit is set to zero.

# If the subnode is a nonterminal, the *index* of the record that *created* that nonterminal is placed in
# the corresponding slot in the record, and the tag bit is set to one.

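# Therefore, each subnode slot holds one tag bit plus enough bits for the larger of the two
# identifier ranges (record indices and terminal numbers):
# 1 + ceil(log_2(max(maximum possible number of emitted records per parse, number of terminals)))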


# For the above example parse tree, the emitted records would be:

# record index: concise form                 verbose form
# record 0:     [C] [1] [0 J]                [nonterminal C] [1 subnode ] [terminal J]
# record 1:     [B] [2] [0 I] [1 0]          [nonterminal B] [2 subnodes] [terminal I] [nonterminal at record 0]
# record 2:     [D] [1] [0 L]                [nonterminal D] [1 subnode ] [terminal L]
# record 3:     [A] [3] [1 1] [0 K] [1 2]    [nonterminal A] [3 subnodes] [nonterminal at record 1] [terminal K] [nonterminal at record 2]

# It is possible to deterministically construct the corresponding parse tree from these records.

# This serialization scheme has the advantage of not requiring read access to the
# serialization memory, which allows for streaming of the parse tree to another component
# on/outside the FPGA. In order to accomplish this, we maintain a "sideband" stack (managed
# in parallel with the primary parse stack) where indices for nonterminals are kept.
# These indices allow "stitching" of the individual records into a coherent parse tree.

# Every edge in the above parse tree that is represented with "*" is an edge that needs to
# be stitched in this way.


# The parameters for instantiating the tree serializer are as follows:
#
# parameters dependent only on the language:
#     * number of nonterminals
#     * number of terminals
#     * list of language rules
#
# parameters dependent on the use case:
#    * bitwidth of parse tree memory
#    * length of parse tree memory
#    * maximum possible number of emitted records per parse
#
# If the last of these (the maximum number of emitted records per parse) is not provided, a
# pessimistic estimate will be calculated:
# (number of bits in memory) / (number of bits to represent the smallest possible record)

# In general, the bitwidths of the record fields do not correspond to the bitwidth of the memory
# available to the FPGA. To address this, we split records across multiple memory addresses as
# necessary. This means that the records of the serialized parse tree won't necessarily be aligned
# in any way. Alignment could be added if desired, but it should not matter much: this data
# structure has to be decoded from the start anyway, since random access is not possible.






class TreeSerializer(Elaboratable):
    def __init__(self, *, item_width, indices_width, stack_depth, serialized_tree_length):
        """Declare the serializer's parameters, I/O signals and backing memory.

        Keyword-only parameters:
            item_width:             width of the symbol-carrying signals (the popped
                                    item, the reduce rule number and the created item).
                                    Each memory word is one bit wider than this.
            indices_width:          width of the sideband-stack record indices; also
                                    used as the width of the memory address port.
            stack_depth:            sizes ``number_to_pop`` as ``range(stack_depth)``.
            serialized_tree_length: depth (number of words) of ``self.mem``.

        No logic is built here; ``elaborate`` connects these signals to the FSM
        and to a write port of ``self.mem``.
        """
        # --- Configuration (plain Python integers) ---
        self.serialized_tree_length = serialized_tree_length
        self.stack_depth = stack_depth
        self.item_width = item_width
        self.indices_width = indices_width
        # Memory addresses and sideband-stack indices share a single width.
        self.mem_address_width = indices_width

        # Memory words are one bit wider than item_width, presumably to carry the
        # terminal/nonterminal tag bit of the record format (TODO confirm).
        word_width = item_width + 1

        # --- Inputs ---
        # Asserted once per reduction.
        self.start_reduction = Signal()
        # Intended to be captured when start_reduction goes high.
        self.number_to_pop = Signal(range(stack_depth))
        self.reduce_rule_number = Signal(item_width)
        self.item_created_by_reduce_rule = Signal(item_width)

        # These change with every popped item.
        self.destroyed_item_valid_in = Signal()
        self.destroyed_item_in = Signal(item_width)           # from the main parse stack
        self.destroyed_item_index_in = Signal(indices_width)  # from the sideband stack

        # --- Outputs to the controlling state machine ---
        self.ready_out = Signal()
        self.internal_fault = Signal()
        # Index to push onto the sideband stack for the newly created item.
        self.serialized_index = Signal(indices_width)

        # --- Serialization-memory interface ---
        self.memory_write_port = Signal(word_width)
        self.memory_address_port = Signal(self.mem_address_width)
        self.memory_write_enable = Signal()

        self.mem = Memory(width=word_width, depth=self.serialized_tree_length)

    def elaborate(self, platform):
        m = Module()

        start_of_record = Signal(self.mem_address_width) # start of next/yet-unwritten node record, advanced only
                                                         # after each reduce

        m.submodules.parse_tree = wport = (self.mem).write_port()



        m.d.comb += wport.en.eq(self.memory_write_enable),
        m.d.comb += wport.addr.eq(self.memory_address_port),
        m.d.comb += wport.data.eq(self.memory_write_port)


        # Per-reduce registered signals:
        number_of_children          = Signal(range(self.stack_depth))

        reduce_rule_number          = Signal(self.item_width)
        item_created_by_reduce_rule = Signal(self.item_width)

        # incremented each cycle
        number_written              = Signal(range(self.stack_depth))
        m.d.comb += self.serialized_index.eq(start_of_record)


        with m.FSM() as fsm:
            with m.State("INITIALIZE"):
                m.d.comb += self.ready_out.eq(0)
                m.d.comb += self.internal_fault.eq(0)

                m.d.sync += start_of_record.eq(0)
                m.d.sync += number_written.eq(0)

                m.next="NODE"


            with m.State("NODE"):
                m.d.comb += self.ready_out.eq(1)
                m.d.comb += self.internal_fault.eq(0)

                #m.d.sync += reduce_rule_number.eq(self.reduce_rule_number)
                m.d.sync += item_created_by_reduce_rule.eq(self.item_created_by_reduce_rule)

                with m.If(self.start_reduction == 1):
                    with m.If(self.destroyed_item_index_in == 0):
                        m.d.comb += self.memory_write_port.eq(self.destroyed_item_in)
                    with m.Else():
                        m.d.comb += self.memory_write_port.eq(self.destroyed_item_index_in)
                    m.d.comb += self.memory_address_port.eq(start_of_record + 2)
                    m.d.comb += self.memory_write_enable.eq(1)

                    m.d.sync += number_of_children.eq(self.number_to_pop)
                    m.d.sync += number_written.eq(1)

                    m.next = "SUBNODES"

                with m.If(self.memory_address_port > (self.serialized_tree_length - 1)):
                    m.next = "ABORT"

            with m.State("SUBNODES"):
                m.d.comb += self.ready_out.eq(0)
                m.d.comb += self.internal_fault.eq(0)

                with m.If(self.destroyed_item_valid_in == 1):
                    with m.If(self.destroyed_item_index_in == 0):
                        m.d.comb += self.memory_write_port.eq(self.destroyed_item_in)
                    with m.Else():
                        m.d.comb += self.memory_write_port.eq(self.destroyed_item_index_in)
                    m.d.comb += self.memory_address_port.eq(start_of_record + 2 + number_written)
                    m.d.comb += self.memory_write_enable.eq(1)

                    m.d.sync += number_written.eq(number_written + 1)


                with m.If(number_written == number_of_children):
                    m.d.comb += self.memory_write_port.eq(item_created_by_reduce_rule)
                    m.d.comb += self.memory_address_port.eq(start_of_record)
                    m.d.comb += self.memory_write_enable.eq(1)

                    m.next = "FIXUP"

                with m.If(self.memory_address_port > (self.serialized_tree_length - 1)):
                    m.next = "ABORT"

            with m.State("FIXUP"):
                    m.d.comb += self.ready_out.eq(0)
                    m.d.comb += self.internal_fault.eq(0)

                    m.d.comb += self.memory_write_port.eq(number_of_children)
                    m.d.comb += self.memory_address_port.eq(start_of_record + 1)
                    m.d.comb += self.memory_write_enable.eq(1)
                    m.d.sync += start_of_record.eq(start_of_record + 2 + number_of_children)

                    m.next = "NODE"

                    with m.If(self.memory_address_port > (self.serialized_tree_length - 1)):
                        m.next = "ABORT"

            with m.State("ABORT"):
                m.d.comb += self.ready_out.eq(0)
                m.d.comb += self.internal_fault.eq(1)