diff --git a/rtl_lib/arbitrary_width_memory.py b/rtl_lib/arbitrary_width_memory.py index 953a74b23a6bc95b4f9d9085e913eb867fcccb80..5a13750a0ca08fd1c50c5d7cda0ab5b3e351eb35 100644 --- a/rtl_lib/arbitrary_width_memory.py +++ b/rtl_lib/arbitrary_width_memory.py @@ -31,30 +31,129 @@ class ArbitraryWidthMemoryBus(Record): super().__init__(ArbitraryWidthMemoryLayout(data_width=data_width, address_width=address_width)) +# initial_data must be folded, spindleed, and mutilated to fit inside the backing memory dimensions + +# later on we will write a function to take care of this automatically + class ArbitraryWidthMemory(Elaboratable): - def __init__(self, *, data_width, address_width): - self.data_width = data_width - self.address_width = address_width + def __init__(self, *, fake_data_width, fake_address_width, backing_memory_data_width, backing_memory_address_width, initial_data): + self.fake_data_width = fake_data_width + self.fake_address_width = fake_address_width + + def is_power_of_two(n): + return (n & (n-1) == 0) + + assert(is_power_of_two(backing_memory_data_width)) + + self.backing_memory_data_width = backing_memory_data_width + self.backing_memory_data_width_bits = backing_memory_data_width.bit_length() - 1 + assert(2**self.backing_memory_data_width_bits == backing_memory_data_width) + + self.backing_memory_address_width = backing_memory_address_width - self.bus = ArbitraryWidthMemoryBus(data_width=data_width, address_width=data_width) + self.backing_memory_length = 256 #len(initial_data) + + assert(2**backing_memory_address_width >= self.backing_memory_length) + + self.bus = ArbitraryWidthMemoryBus(data_width=fake_data_width, address_width=fake_address_width) + self.backing_memory = Memory(width=backing_memory_data_width, depth=self.backing_memory_length, init=initial_data) def elaborate(self, platform): m = Module() + bus = self.bus + + # regs to persist output of last read onto the output bus until it's accepted by the downstream consumer + last_r_data = Signal(self.fake_data_width) + last_r_data_valid = Signal(1) + + + + unwrapped_bit_index = Signal(range(self.backing_memory_length*self.backing_memory_data_width)) #probably cannot be made shorter. + + left_bit_index = Signal(self.backing_memory_data_width) + right_bit_index = Signal(self.backing_memory_data_width) + end_bit_pseudo_index = Signal(range(self.backing_memory_length*self.backing_memory_data_width)) # can be made shorter + additional_words = Signal(self.backing_memory_address_width) # can also be amde shorter + starting_word = Signal(self.backing_memory_address_width) + + + + # State across clock cycles + next_address = Signal(self.backing_memory_address_width) + additional_words_regd = Signal(self.backing_memory_address_width) + + # pseudo-output signals, we connect these to memory and a shift register later: + valid_fetch = Signal(1) + fetch_address = Signal(self.backing_memory_address_width) + + + m.submodules.read_port = read_port = self.backing_memory.read_port() + m.d.comb += read_port.addr.eq(fetch_address) + m.d.comb += self.bus.r_data.eq(read_port.data) + + with m.FSM() as fsm: + with m.State("RESET"): + m.next ="READY" + + with m.State("READY"): + m.d.comb += bus.ready_out.eq(1) + + with m.If(last_r_data_valid == 1): + m.d.comb += bus.valid_out.eq(1) + m.d.comb += bus.r_data.eq(last_r_data) + + with m.If(bus.valid_in == 1): + # the index and bit-index computation goes as follows: + + # 1) We calculate an unwrapped bit index by multiplying the fake index by the fake data width + m.d.comb += unwrapped_bit_index.eq(bus.r_addr * self.fake_data_width) + m.d.comb += starting_word.eq(unwrapped_bit_index[self.backing_memory_data_width_bits:]) + m.d.comb += fetch_address.eq(starting_word) + m.d.comb += left_bit_index.eq(unwrapped_bit_index[:self.backing_memory_data_width_bits]) + m.d.comb += end_bit_pseudo_index.eq(left_bit_index + self.fake_data_width-1) + m.d.comb += additional_words.eq(end_bit_pseudo_index[self.backing_memory_data_width_bits:]) + + with m.If(additional_words == 0): + m.d.comb += right_bit_index.eq(end_bit_pseudo_index[:self.backing_memory_data_width_bits]) + with m.Else(): + m.d.comb += right_bit_index.eq(self.backing_memory_data_width-1) + m.d.sync += next_address.eq(fetch_address + 1) + m.d.sync += additional_words_regd.eq(additional_words) + m.next="ADD" + + + with m.State("ADD"): + # we handle both the full-word fetches and the final (potentially partial word) fetch here + with m.If(additional_words_regd == 1): # special case, we may not have to include the whole word! + m.d.comb += left_bit_index.eq(0) + m.d.comb += right_bit_index.eq(end_bit_pseudo_index[:self.backing_memory_data_width_bits]) + m.d.comb += fetch_address.eq(next_address) + m.next = "STALL" + with m.Else(): # non-special case, fetch the whole word. + m.d.comb += left_bit_index.eq(0) + m.d.comb += right_bit_index.eq(self.backing_memory_data_width-1) + m.d.sync += next_address.eq(next_address + 1) + m.d.comb += fetch_address.eq(next_address) + m.d.sync += additional_words_regd.eq(additional_words_regd - 1) + with m.State("STALL"): + m.next="STALL" return m # This is non-synthesizable but is intended to provide a model for formal verification. +# We initialize the "sim_memory" with the same data that's in the + class GoldenArbitraryWidthMemory(Elaboratable): def __init__(self, *, data_width, address_width, sim_memory): self.data_width = data_width self.address_width = address_width self.sim_memory_size = len(sim_memory) - self.memory = Memory(width=data_width, depth=self.sim_memory_size, init=sim_memory) + self.fake_memory = Memory(width=data_width, depth=self.sim_memory_size, init=sim_memory) self.bus = ArbitraryWidthMemoryBus(data_width=data_width, address_width=address_width) @@ -62,22 +161,15 @@ class GoldenArbitraryWidthMemory(Elaboratable): m = Module() write_ptr = Signal(range(self.sim_memory_size)) - read_ptr = Signal(range(self.sim_memory_size)) - m.submodules.read_port = read_port = self.memory.read_port() - m.submodules.write_port = write_port = self.memory.write_port() + m.submodules.read_port = read_port = self.fake_memory.read_port() bus = self.bus with m.If(bus.valid_in == 1): - with m.If(bus.write_enable == 1): - m.d.comb += write_port.en.eq(1) - m.d.comb += write_port.addr.eq(bus.w_addr) - m.d.comb += write_port.data.eq(bus.w_data) - with m.Else(): - #m.d.comb += read_port.en.eq(0) - m.d.comb += read_port.addr.eq(bus.r_addr) - m.d.comb += bus.r_data.eq(read_port.data) - m.d.sync += bus.valid_out.eq(1) + #m.d.comb += read_port.en.eq(0) + m.d.comb += read_port.addr.eq(bus.r_addr) + m.d.comb += bus.r_data.eq(read_port.data) + m.d.sync += bus.valid_out.eq(1) return m @@ -93,18 +185,17 @@ class DummyPlug(Elaboratable): def elaborate(self, platform): m = Module() - m.submodules.AWMem = AWMem = GoldenArbitraryWidthMemory(data_width=7, address_width=8, sim_memory=[0x40,0x41,0x42, 0x43, 0x44, 0x45,0x46,0x47]) - counter = Signal(8) - + m.submodules.FakeAWMem = FakeAWMem = ArbitraryWidthMemory(fake_data_width=16, + fake_address_width=8, initial_data=[0x23,0x45, 0x67, 0x89, 0xab,0xcd,0xef], + backing_memory_data_width=8, backing_memory_address_width=8) + counter = Signal(8, reset=1) - m.d.sync += counter.eq(counter+1) +# with m.If(FakeAWMem.bus.ready_out == 1): + #m.d.sync += counter.eq(counter+1) #with m.If(counter == 4): - m.d.comb += AWMem.bus.valid_in.eq(1) - m.d.comb += AWMem.bus.write_enable.eq(0) - m.d.comb += AWMem.bus.r_addr.eq(counter) - #m.d.comb += AWMem.bus.w_addr.eq(counter) - #m.d.comb += AWMem.bus.w_data.eq(counter) + m.d.comb += FakeAWMem.bus.valid_in.eq(1) + m.d.comb += FakeAWMem.bus.r_addr.eq(counter) return m