diff --git a/combinatorial_LR_parser.py b/combinatorial_LR_parser.py
index 5ac8ef2bab5aa7a372f584b410354751323e638e..3394106c902e5cd13713f6cacd005e15661d67db 100644
--- a/combinatorial_LR_parser.py
+++ b/combinatorial_LR_parser.py
@@ -558,8 +558,10 @@ class TreeSerializer(Elaboratable):
     def elaborate(self, platform):
         m = Module()
 
-        writepoint                  = Signal(self.mem_address_width) # where to write next
-        backpointer                 = Signal(self.mem_address_width) # pointer to the "number of children" field
+        start_of_record = Signal(self.mem_address_width) # start of next/yet-unwritten node record, advanced only
+                                                         # after each reduce
+
+        writepointer    = Signal(self.mem_address_width) # current pointer for in-progress writes
 
         mem = Memory(width=(self.item_width + 1), depth=32)
         m.submodules.parse_tree = wport = mem.write_port()
@@ -580,20 +582,33 @@ class TreeSerializer(Elaboratable):
 
 
         with m.FSM() as fsm:
+
+
+            with m.State("INITIALIZE"):
+
+                m.d.comb += self.ready_out.eq(0)
+
+
+                m.d.sync += writepointer.eq(2)
+                m.d.sync += start_of_record.eq(0)
+
+                m.next="NODE"
+
+
+
             with m.State("NODE"):
                 m.d.comb += self.ready_out.eq(1)
                 #m.d.sync += reduce_rule_number.eq(self.reduce_rule_number)
-                #m.d.sync += item_created_by_reduce_rule.eq(self.item_created_by_reduce_rule)
+                m.d.sync += item_created_by_reduce_rule.eq(self.item_created_by_reduce_rule)
 
                 with m.If(self.start_reduction == 1):
                     m.d.comb += self.memory_write_port.eq(self.destroyed_item_in)
-                    m.d.comb += self.memory_address_port.eq(writepoint)
+                    m.d.comb += self.memory_address_port.eq(writepointer)
                     m.d.comb += self.memory_write_enable.eq(1)
 
                     m.d.sync += number_of_children.eq(self.number_to_pop)
                     m.d.sync += number_remaining.eq(self.number_to_pop)
-                    m.d.sync += backpointer.eq(writepoint + 1)
-                    m.d.sync += writepoint.eq(writepoint + 2)
+                    m.d.sync += writepointer.eq(writepointer + 1)
                     m.next = "SUBNODES"
 
 
@@ -601,19 +616,27 @@ class TreeSerializer(Elaboratable):
                 m.d.comb += self.ready_out.eq(0)
                 with m.If(self.destroyed_item_valid_in == 1):
                     m.d.comb += self.memory_write_port.eq(self.destroyed_item_in)
-                    m.d.comb += self.memory_address_port.eq(writepoint)
+                    m.d.comb += self.memory_address_port.eq(writepointer)
                     m.d.comb += self.memory_write_enable.eq(1)
-                    m.d.sync += writepoint.eq(writepoint + 1)
 
+                    m.d.sync += writepointer.eq(writepointer + 1)
                     m.d.sync += number_remaining.eq(number_remaining - 1)
 
 
-                with m.If(number_remaining == 0):
+                with m.If(number_remaining == 1):
+                    m.d.comb += self.memory_write_port.eq(item_created_by_reduce_rule)
+                    m.d.comb += self.memory_address_port.eq(start_of_record)
+
+                    m.d.comb += self.memory_write_enable.eq(1)
+
+                    m.next = "FIXUP"
+
+            with m.State("FIXUP"):
                     m.d.comb += self.memory_write_port.eq(number_of_children)
-                    m.d.comb += self.memory_address_port.eq(backpointer)
+                    m.d.comb += self.memory_address_port.eq(start_of_record + 1)
                     m.d.comb += self.memory_write_enable.eq(1)
+                    m.d.sync += start_of_record.eq(writepointer + 1)
 
-                    m.next = "NODE"
 
 
         return m