from nmigen import *
from nmigen.hdl.rec import *
from nmigen.cli import main
# An example parse tree looks like:
#                          nonterminal A
#                        **      |      **
#                      **        |        **
#                    **          |          **
#                   *            |            *
#        nonterminal B       terminal K       nonterminal D
#        /          *                               |
#       /             *                             |
#  terminal I    nonterminal C                 terminal L
#                       \
#                        \
#                     terminal J
# As it is parsed, we serialize it on the fly. The serialization scheme represents the tree
# as a sequence of records, each of which corresponds to a reduction. In other words, each
# nonterminal appearing in the parse tree will have a corresponding record.
#
# The list of records, in order, is therefore as follows:
#
# nonterminal C -> terminal J
# nonterminal B -> terminal I, nonterminal C
# nonterminal D -> terminal L
# nonterminal A -> nonterminal B, terminal K, nonterminal D
# We note that for every nonterminal on the right side of a record, that nonterminal has already
# appeared on the left side of a previous record. This is because of how a bottom-up parse proceeds.
# Each serialized parse tree record looks like this:
#
# FIELD 0                 FIELD 1                     FIELD 2                   FIELD 3                          FIELD N                       FIELD N+1
# [Nonterminal symbol]    [N = number of subnodes]    [tag bit 1, subnode 1]    [tag bit 2, subnode 2]    ...    [tag bit N-1, subnode N-1]    [tag bit N, subnode N]
# Field 0 indicates the symbol number of the nonterminal created by the reduction. Its bitwidth is
# ceil(log_2(number of nonterminals)).
# The number of subnodes, N, is represented as an unsigned integer. To determine its bitwidth, we
# examine all the language rules, find the rule with the most symbols (terminal and nonterminal
# alike) on its right-hand side, and take the ceiling of the log_2 of that length:
# bitwidth(N) = ceil(log_2(max(len(RHS of rule 1), len(RHS of rule 2), len(RHS of rule 3), ...)))
# Following field 1 there are N subnode slots, each containing a tag bit and a subnode identifier:
# If the subnode is a terminal, the symbol/token number of that terminal is placed in the corresponding
# slot in the record, and the tag bit is set to zero.
# If the subnode is a nonterminal, the *index* of the record that *created* that nonterminal is placed in
# the corresponding slot in the record, and the tag bit is set to one.
# Therefore, the bitwidth of each subnode slot is:
# 1 + ceil(log_2(max(maximum possible number of emitted records per parse, number of terminals)))
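#
# As a quick illustration of the width formulas above, the helper below computes them for a
# hypothetical grammar. This is a sketch for exposition only: the rule-list format and the function
# name are made up here and are not part of the serializer's interface.
def example_field_widths(num_nonterminals, num_terminals, rules, max_records_per_parse):
    """Illustrative only: compute the record field widths described in the comments above.

    `rules` is assumed to be a list of (lhs, rhs_symbols) pairs."""
    from math import ceil, log2
    symbol_width = max(1, ceil(log2(num_nonterminals)))                                # field 0
    count_width = max(1, ceil(log2(max(len(rhs) for _lhs, rhs in rules))))             # field 1 (N)
    subnode_width = 1 + max(1, ceil(log2(max(max_records_per_parse, num_terminals))))  # tag bit + identifier
    return symbol_width, count_width, subnode_width
#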
# For the above example parse tree, the emitted records would be:
# record index:   concise form                 verbose form
# record 0:       [C] [1] [0 J]                [nonterminal C] [1 subnode ] [terminal J]
# record 1:       [B] [2] [0 I] [1 0]          [nonterminal B] [2 subnodes] [terminal I] [nonterminal at record 0]
# record 2:       [D] [1] [0 L]                [nonterminal D] [1 subnode ] [terminal L]
# record 3:       [A] [3] [1 1] [0 K] [1 2]    [nonterminal A] [3 subnodes] [nonterminal at record 1] [terminal K] [nonterminal at record 2]
# It is possible to deterministically construct the corresponding parse tree from these records.
# This serialization scheme has the advantage of not requiring read access to the
# serialization memory, which allows for streaming of the parse tree to another component
# on/outside the FPGA. In order to accomplish this, we maintain a "sideband" stack (managed
# in parallel with the primary parse stack) where indices for nonterminals are kept.
# These indices allow "stitching" of the individual records into a coherent parse tree.
# Every edge in the above parse tree that is represented with "*" is an edge that needs to
# be stitched in this way.
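#
# A minimal software-side sketch of that reconstruction (for exposition only; the function and the
# record tuples below mirror the example records above and are not part of the hardware interface):
def example_decode(records):
    """Illustrative only: rebuild a nested (symbol, children) tree from serialized records.

    Each record is (symbol, [(tag, value), ...]); tag 0 means `value` is a terminal symbol,
    tag 1 means `value` is the index of an earlier record, which is already decoded because
    records are emitted in bottom-up order."""
    decoded = []
    for symbol, subnodes in records:
        children = [decoded[value] if tag else value for tag, value in subnodes]
        decoded.append((symbol, children))
    return decoded[-1]  # the last record emitted is the root of the parse tree
# For the example above:
#   example_decode([("C", [(0, "J")]),
#                   ("B", [(0, "I"), (1, 0)]),
#                   ("D", [(0, "L")]),
#                   ("A", [(1, 1), (0, "K"), (1, 2)])])
# returns ("A", [("B", ["I", ("C", ["J"])]), "K", ("D", ["L"])])
#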
# The parameters for instantiating the tree serializer are as follows:
#
# parameters dependent only on the language:
# * number of nonterminals
# * number of terminals
# * list of language rules
#
# parameters dependent on the use case:
# * bitwidth of parse tree memory
# * length of parse tree memory
# * maximum possible number of emitted records per parse
#
# If the maximum possible number of emitted records per parse is not provided, a pessimistic
# estimate will be calculated:
# (number of bits in memory) / (number of bits to represent the smallest possible record)
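#
# Sketch of that pessimistic estimate (illustrative only; the names here are not used elsewhere):
def example_max_records_estimate(memory_width, memory_depth, smallest_record_bits):
    """Illustrative only: bound the record count by how many copies of the smallest
    possible record fit into the serialization memory."""
    return (memory_width * memory_depth) // smallest_record_bits
#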
# In general, the bitwidths of the record fields do not correspond to the bitwidth of the memory
# available on the FPGA. To address this, we split the records across multiple memory addresses as
# necessary. This means that the records of the serialized parse tree won't necessarily fall on any
# sort of alignment boundary. If desired, we can add functionality to do alignment, but I do not
# think it is a big deal, since this data structure has to be decoded from the start anyway -- no
# random access is possible.
class TreeSerializer(Elaboratable):
def __init__(self, *, item_width, indices_width, stack_depth, serialized_tree_length):
# Parameters
self.item_width = item_width
self.stack_depth = stack_depth
self.indices_width = indices_width
self.mem_address_width = indices_width
self.serialized_tree_length = serialized_tree_length
# inputs
# Activated only once per reduction
self.start_reduction = Signal(1)
# When ^ goes high, the three below are registered
self.number_to_pop = Signal(range(stack_depth))
self.reduce_rule_number = Signal(item_width)
self.item_created_by_reduce_rule = Signal(item_width)
# Varies with every popped item
self.destroyed_item_valid_in = Signal(1)
self.destroyed_item_in = Signal(item_width) # off main stack
self.destroyed_item_index_in = Signal(indices_width) # off side stack
# outputs
# output to state machine
self.ready_out = Signal(1)
self.internal_fault = Signal(1)
self.serialized_index = Signal(indices_width) # push *this* onto side stack for the
# newly created item
# interface with serialization memory
self.memory_write_port = Signal(item_width + 1)
self.memory_address_port = Signal(self.mem_address_width)
self.memory_write_enable = Signal(1)
self.mem = Memory(width=(self.item_width + 1), depth=serialized_tree_length)
def elaborate(self, platform):
m = Module()
start_of_record = Signal(self.mem_address_width) # start of next/yet-unwritten node record, advanced only
# after each reduce
        m.submodules.parse_tree = wport = self.mem.write_port()
        m.d.comb += wport.en.eq(self.memory_write_enable)
        m.d.comb += wport.addr.eq(self.memory_address_port)
        m.d.comb += wport.data.eq(self.memory_write_port)
# Per-reduce registered signals:
number_of_children = Signal(range(self.stack_depth))
reduce_rule_number = Signal(self.item_width)
item_created_by_reduce_rule = Signal(self.item_width)
        # incremented once per subnode written for the current record
number_written = Signal(range(self.stack_depth))
m.d.comb += self.serialized_index.eq(start_of_record)
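        # Record layout in memory, starting at start_of_record:
        #   [created nonterminal symbol] [number of children] [child 0] [child 1] ... [child N-1]
        # A reduction begins when start_reduction is pulsed (with the first destroyed item presented
        # in the same cycle); the remaining destroyed items are then streamed in, qualified by
        # destroyed_item_valid_in.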
with m.FSM() as fsm:
with m.State("INITIALIZE"):
m.d.comb += self.ready_out.eq(0)
m.d.comb += self.internal_fault.eq(0)
m.d.sync += start_of_record.eq(0)
m.d.sync += number_written.eq(0)
m.next="NODE"
with m.State("NODE"):
m.d.comb += self.ready_out.eq(1)
m.d.comb += self.internal_fault.eq(0)
#m.d.sync += reduce_rule_number.eq(self.reduce_rule_number)
m.d.sync += item_created_by_reduce_rule.eq(self.item_created_by_reduce_rule)
with m.If(self.start_reduction == 1):
with m.If(self.destroyed_item_index_in == 0):
m.d.comb += self.memory_write_port.eq(self.destroyed_item_in)
with m.Else():
m.d.comb += self.memory_write_port.eq(self.destroyed_item_index_in)
m.d.comb += self.memory_address_port.eq(start_of_record + 2)
m.d.comb += self.memory_write_enable.eq(1)
m.d.sync += number_of_children.eq(self.number_to_pop)
m.d.sync += number_written.eq(1)
m.next = "SUBNODES"
with m.If(self.memory_address_port > (self.serialized_tree_length - 1)):
m.next = "ABORT"
with m.State("SUBNODES"):
m.d.comb += self.ready_out.eq(0)
m.d.comb += self.internal_fault.eq(0)
with m.If(self.destroyed_item_valid_in == 1):
with m.If(self.destroyed_item_index_in == 0):
m.d.comb += self.memory_write_port.eq(self.destroyed_item_in)
with m.Else():
m.d.comb += self.memory_write_port.eq(self.destroyed_item_index_in)
m.d.comb += self.memory_address_port.eq(start_of_record + 2 + number_written)
m.d.comb += self.memory_write_enable.eq(1)
m.d.sync += number_written.eq(number_written + 1)
with m.If(number_written == number_of_children):
m.d.comb += self.memory_write_port.eq(item_created_by_reduce_rule)
m.d.comb += self.memory_address_port.eq(start_of_record)
m.d.comb += self.memory_write_enable.eq(1)
m.next = "FIXUP"
with m.If(self.memory_address_port > (self.serialized_tree_length - 1)):
m.next = "ABORT"
with m.State("FIXUP"):
m.d.comb += self.ready_out.eq(0)
m.d.comb += self.internal_fault.eq(0)
m.d.comb += self.memory_write_port.eq(number_of_children)
m.d.comb += self.memory_address_port.eq(start_of_record + 1)
m.d.comb += self.memory_write_enable.eq(1)
m.d.sync += start_of_record.eq(start_of_record + 2 + number_of_children)
m.next = "NODE"
with m.If(self.memory_address_port > (self.serialized_tree_length - 1)):
m.next = "ABORT"
with m.State("ABORT"):
m.d.comb += self.ready_out.eq(0)
m.d.comb += self.internal_fault.eq(1)
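
# A minimal usage sketch: the parameter values below are illustrative only and are not derived from
# any particular grammar. Running this file with nmigen's command-line helper (imported above as
# `main`) exposes, e.g., a `generate` subcommand for emitting RTL.
if __name__ == "__main__":
    serializer = TreeSerializer(item_width=8, indices_width=10, stack_depth=64,
                                serialized_tree_length=1024)
    main(serializer, ports=[
        serializer.start_reduction, serializer.number_to_pop, serializer.reduce_rule_number,
        serializer.item_created_by_reduce_rule, serializer.destroyed_item_valid_in,
        serializer.destroyed_item_in, serializer.destroyed_item_index_in,
        serializer.ready_out, serializer.internal_fault, serializer.serialized_index,
    ])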