Justin Chou committed on
Commit 679abc4
1 Parent(s): e9c2b75
This view is limited to 50 files because the commit contains too many changes.
Files changed (50)
  1. Dockerfile +14 -4
  2. demo/app.py +0 -312
  3. demo/readme.md +0 -9
  4. demo/requirements.txt +0 -17
  5. hardware_accelerators/__init__.py +16 -1
  6. hardware_accelerators/analysis/__init__.py +0 -0
  7. hardware_accelerators/analysis/config.py +25 -0
  8. hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/config.mk +10 -0
  9. hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/constraint.sdc +1 -0
  10. hardware_accelerators/analysis/generate.py +958 -0
  11. hardware_accelerators/analysis/hardware_stats.py +458 -0
  12. hardware_accelerators/analysis/mnist_eval.py +274 -0
  13. hardware_accelerators/analysis/simple_circuits.py +258 -0
  14. hardware_accelerators/analysis/verilog_export.py +86 -0
  15. hardware_accelerators/analysis/verilog_output/pipelined_adder_BF16.v +37 -0
  16. hardware_accelerators/analysis/verilog_output/pipelined_adder_Float16.v +37 -0
  17. hardware_accelerators/analysis/verilog_output/pipelined_adder_Float32.v +37 -0
  18. hardware_accelerators/analysis/verilog_output/pipelined_adder_Float8.v +37 -0
  19. hardware_accelerators/analysis/verilog_output/pipelined_multiplier_BF16.v +37 -0
  20. hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float16.v +37 -0
  21. hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float32.v +37 -0
  22. hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float8.v +37 -0
  23. hardware_accelerators/analysis/verilog_output/simple_adder_BF16.v +21 -0
  24. hardware_accelerators/analysis/verilog_output/simple_adder_Float16.v +21 -0
  25. hardware_accelerators/analysis/verilog_output/simple_adder_Float32.v +21 -0
  26. hardware_accelerators/analysis/verilog_output/simple_adder_Float8.v +21 -0
  27. hardware_accelerators/analysis/verilog_output/simple_multiplier_BF16.v +21 -0
  28. hardware_accelerators/analysis/verilog_output/simple_multiplier_Float16.v +21 -0
  29. hardware_accelerators/analysis/verilog_output/simple_multiplier_Float32.v +21 -0
  30. hardware_accelerators/analysis/verilog_output/simple_multiplier_Float8.v +21 -0
  31. hardware_accelerators/app.py +388 -0
  32. hardware_accelerators/compile.py +167 -0
  33. hardware_accelerators/dtypes/__init__.py +3 -1
  34. hardware_accelerators/dtypes/base.py +12 -3
  35. hardware_accelerators/dtypes/bfloat16.py +4 -0
  36. hardware_accelerators/dtypes/float16.py +167 -0
  37. hardware_accelerators/dtypes/float32.py +174 -0
  38. hardware_accelerators/dtypes/float8.py +4 -0
  39. hardware_accelerators/nn/lmul.py +135 -0
  40. hardware_accelerators/nn/precision.py +264 -0
  41. hardware_accelerators/nn/precision_eval.py +280 -0
  42. hardware_accelerators/nn/run_precision_comparison.py +78 -0
  43. hardware_accelerators/nn/train.py +0 -2
  44. hardware_accelerators/nn/util.py +3 -1
  45. hardware_accelerators/rtllib/__init__.py +10 -2
  46. hardware_accelerators/rtllib/accelerator.py +407 -113
  47. hardware_accelerators/rtllib/activations.py +69 -7
  48. hardware_accelerators/rtllib/adders.py +54 -8
  49. hardware_accelerators/rtllib/legacy.py +71 -0
  50. hardware_accelerators/rtllib/lmul.py +63 -60
Dockerfile CHANGED
@@ -1,3 +1,4 @@
+ # Dockerfile for the demo
  FROM python:3.12-slim

  WORKDIR /code
@@ -9,20 +10,29 @@ RUN apt-get update && apt-get install -y \
      && rm -rf /var/lib/apt/lists/*

  # Install Python packages
- COPY demo/requirements.txt requirements.txt
+ COPY ./requirements.txt requirements.txt
  RUN pip install --no-cache-dir -r requirements.txt

  # Copy the model and application files
  COPY models/ /code/models/
  COPY hardware_accelerators/ /code/hardware_accelerators/
- COPY demo/app.py /code/app.py
+ COPY results/component_data.csv /code/data/

- # Set environment variables for Gradio
+
+ # Set environment variables
+ ENV HWA_CACHE_DIR=/code/hardware_accelerators/cache
+ ENV COMPONENT_DATA_PATH=/code/data/component_data.csv
  ENV GRADIO_SERVER_NAME=0.0.0.0
  ENV GRADIO_SERVER_PORT=7860

+ # Copy the component data
+ COPY results/component_data.csv /code/data/component_data.csv
+
+ # Compile the simulations
+ RUN python3 -m hardware_accelerators.compile
+
  # Expose the port Gradio runs on
  EXPOSE 7860

  # Command to run the Gradio app
- CMD ["python", "app.py"]
+ CMD ["python3", "-m", "hardware_accelerators.app"]
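The updated image drives configuration through environment variables and pre-compiles the simulators at build time. A minimal sketch of how application code might pick up those variables (the names come from the Dockerfile above; the fallback paths are illustrative assumptions, not values from this commit):

import os
from pathlib import Path

# Variables set in the Dockerfile; the defaults here are placeholders for local runs.
cache_dir = Path(os.environ.get("HWA_CACHE_DIR", "hardware_accelerators/cache"))
component_data = Path(os.environ.get("COMPONENT_DATA_PATH", "results/component_data.csv"))

cache_dir.mkdir(parents=True, exist_ok=True)
print(f"simulation cache: {cache_dir}")
print(f"component data:   {component_data}")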
demo/app.py DELETED
@@ -1,312 +0,0 @@
1
- import sys
2
- import gradio as gr
3
- from gradio.components.image_editor import EditorValue
4
- import numpy as np
5
- import torch
6
- import pandas as pd
7
- from PIL import Image
8
- import torch
9
- import torch.nn as nn
10
- import torchvision.transforms as transforms
11
- import tqdm
12
- import pyrtl
13
-
14
- sys.path.append(".")
15
- from hardware_accelerators.nn.util import softmax
16
- from hardware_accelerators.simulation.matrix_utils import (
17
- bias_trick,
18
- count_total_gemv_tiles,
19
- generate_gemv_tiles,
20
- )
21
- from hardware_accelerators.rtllib.adders import float_adder
22
- from hardware_accelerators.rtllib.multipliers import float_multiplier
23
- from hardware_accelerators.dtypes.bfloat16 import BF16
24
- from hardware_accelerators.dtypes.float8 import Float8
25
- from hardware_accelerators.nn import model_factory, get_pytorch_device
26
- from hardware_accelerators.rtllib import (
27
- AcceleratorConfig,
28
- Accelerator,
29
- lmul_fast,
30
- float_multiplier,
31
- )
32
- from hardware_accelerators.simulation import CompiledSimulator, AcceleratorSimulator
33
-
34
-
35
- # ------------ CONSTANTS ------------ #
36
-
37
- # Load the trained model
38
- model_path = "models/mlp_mnist.pth"
39
- model = model_factory()
40
- model.load_state_dict(
41
- torch.load(model_path, map_location=get_pytorch_device(), weights_only=True)
42
- )
43
- model.eval()
44
-
45
- classes = [
46
- "zero",
47
- "one",
48
- "two",
49
- "three",
50
- "four",
51
- "five",
52
- "six",
53
- "seven",
54
- "eight",
55
- "nine",
56
- ]
57
- labels_value = {label: 0.0 for label in classes}
58
-
59
- accelerator_dtypes = ["float8", "bfloat16"]
60
- # accelerator_dtypes = ["float8", "float16", "bfloat16", "float32"]
61
-
62
- dtype_map = {"float8": Float8, "bfloat16": BF16}
63
-
64
- default_config = {
65
- "activations_dtype": "bfloat16",
66
- "weights_dtype": "bfloat16",
67
- "size": 4,
68
- "multiplication": "IEEE 754",
69
- }
70
-
71
- mult_map = {
72
- "IEEE 754": float_multiplier,
73
- "l-mul": lmul_fast,
74
- }
75
-
76
-
77
- # ------------ Event Listener Functions ------------ #
78
-
79
-
80
- def image_to_tensor(sketchpad: EditorValue):
81
- image = sketchpad["composite"]
82
- image = image.resize((28, 28), Image.Resampling.LANCZOS) # type: ignore
83
- img_array = np.transpose(np.array(image), (2, 0, 1))[-1]
84
-
85
- # Preprocessing: convert image to tensor and normalize
86
- transform = transforms.Compose(
87
- [
88
- transforms.ToTensor(),
89
- transforms.Normalize((0.1307,), (0.3081,)),
90
- ]
91
- )
92
- tensor_image = transform(img_array).unsqueeze(0) # Add batch dimension
93
- return tensor_image
94
-
95
-
96
- def torch_predict(sketchpad: EditorValue):
97
- tensor_image = image_to_tensor(sketchpad)
98
- with torch.no_grad():
99
- logits = model(tensor_image)
100
- probabilities = torch.softmax(logits, dim=1).squeeze(0)
101
- result = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
102
- return result
103
-
104
-
105
- def update_accelerator_config(
106
- activations_dtype: str, weights_dtype: str, size: int, multiplication: str
107
- ) -> AcceleratorConfig:
108
-
109
- # Triggered by run simulation button
110
- print("update_accelerator_config fn called")
111
- print(activations_dtype, weights_dtype, size, multiplication)
112
-
113
- return AcceleratorConfig(
114
- num_weight_tiles=4,
115
- weight_type=dtype_map[weights_dtype],
116
- data_type=dtype_map[activations_dtype],
117
- array_size=size,
118
- pe_multiplier=mult_map[multiplication],
119
- pe_adder=float_adder,
120
- accum_adder=float_adder,
121
- accum_addr_width=8,
122
- accum_type=dtype_map[activations_dtype],
123
- pipeline=False,
124
- )
125
-
126
-
127
- def simulator_predict(sketchpad: EditorValue, config: AcceleratorConfig):
128
- # if config == DEFAULT_ACCELERATOR_CONFIG:
129
- # sim = ACCELERATOR_SIM
130
- # else:
131
- # sim = AcceleratorSimulator(config=config)
132
-
133
- sim = CompiledSimulator(config=config)
134
- image = image_to_tensor(sketchpad).detach().numpy().flatten()
135
- probabilities = sim.run_mlp(model, image)
136
- result = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
137
- return result
138
-
139
-
140
- def sim_predict_progress(
141
- sketchpad: EditorValue,
142
- config: AcceleratorConfig,
143
- gr_progress=gr.Progress(track_tqdm=True),
144
- ):
145
- pyrtl.reset_working_block()
146
- simulator = CompiledSimulator(config=config)
147
- chunk_size = config.array_size
148
-
149
- x = image_to_tensor(sketchpad).detach().numpy().flatten()
150
- probabilities = simulator.run_mlp(model, x)
151
- return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
152
-
153
- weights_1 = model.fc1.weight.numpy(force=True)
154
- bias_1 = model.fc1.bias.numpy(force=True)
155
- weights_2 = model.fc2.weight.numpy(force=True)
156
- bias_2 = model.fc2.bias.numpy(force=True)
157
-
158
- # Add bias to first layer weights and 1 to activations
159
- W_aug, x_aug = bias_trick(weights_1, bias_1, x)
160
-
161
- total_tiles = count_total_gemv_tiles([(784, 128), (128, 10)], chunk_size)
162
- progress = tqdm.tqdm(total=total_tiles)
163
-
164
- tile_generator = generate_gemv_tiles(x_aug, W_aug, chunk_size)
165
-
166
- for tile in tile_generator:
167
- simulator.load_weights(weights=tile.matrix.T, tile_addr=0)
168
- simulator.execute_instruction(
169
- load_new_weights=True,
170
- weight_tile_addr=0,
171
- data_vec=tile.vector,
172
- accum_addr=tile.index,
173
- accum_mode=not tile.first,
174
- activation_func="relu",
175
- activation_enable=tile.last,
176
- flush_pipeline=True,
177
- )
178
- progress.update()
179
-
180
- simulator.execute_instruction(nop=True)
181
- simulator.execute_instruction(nop=True)
182
-
183
- sim_fc1 = np.array(simulator.output_trace)
184
-
185
- # simulator.reset_output_trace()
186
- simulator.output_trace = []
187
-
188
- W2_aug, fc1_aug = bias_trick(weights_2, bias_2, sim_fc1.flatten())
189
-
190
- fc2_tile_generator = generate_gemv_tiles(fc1_aug, W2_aug, chunk_size)
191
-
192
- for tile in fc2_tile_generator:
193
- simulator.load_weights(weights=tile.matrix.T, tile_addr=0)
194
- simulator.execute_instruction(
195
- load_new_weights=True,
196
- weight_tile_addr=0,
197
- data_vec=tile.vector,
198
- accum_addr=tile.index,
199
- accum_mode=not tile.first,
200
- activation_enable=tile.last,
201
- flush_pipeline=True,
202
- )
203
- progress.update()
204
-
205
- simulator.execute_instruction(nop=True)
206
- simulator.execute_instruction(nop=True)
207
-
208
- sim_fc2 = np.array(simulator.output_trace).flatten()
209
- probabilities = softmax(sim_fc2)
210
- result = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
211
- return result
212
-
213
-
214
- # ------------ Blocks UI Layout ------------ #
215
-
216
- with gr.Blocks(fill_height=False) as demo:
217
-
218
- accelerator_config = gr.State()
219
-
220
- gr.Markdown("## Draw a digit to see the model's prediction")
221
- with gr.Row(equal_height=True):
222
- with gr.Column():
223
- sketchpad = gr.Sketchpad(
224
- # label="Draw a digit",
225
- type="pil", # Changed to PIL
226
- transforms=(),
227
- layers=False,
228
- canvas_size=(400, 400),
229
- )
230
-
231
- with gr.Row():
232
- predict_btn = gr.Button("Run Hardware Simulation", variant="primary")
233
-
234
- # with gr.Accordion("Accelerator Configuration", open=True):
235
- with gr.Group():
236
- weight_dtype_component = gr.Radio(
237
- label="Weights d-type",
238
- choices=accelerator_dtypes,
239
- value=default_config["weights_dtype"],
240
- interactive=True,
241
- )
242
- activation_dtype_component = gr.Radio(
243
- label="Activations d-type",
244
- choices=accelerator_dtypes,
245
- value=default_config["activations_dtype"],
246
- interactive=True,
247
- )
248
- systolic_array_size_component = gr.Slider(
249
- label="Systolic Array Size",
250
- info="Large values will significantly slow down the simulation",
251
- minimum=2,
252
- maximum=16,
253
- step=1,
254
- value=default_config["size"],
255
- interactive=True,
256
- )
257
- multiply_component = gr.Radio(
258
- label="Multiplication Type",
259
- choices=["IEEE 754", "l-mul"],
260
- value=default_config["multiplication"],
261
- interactive=True,
262
- )
263
-
264
- with gr.Column():
265
- pytorch_output = gr.Label(
266
- label="Pytorch Ground Truth Predictions", value=labels_value
267
- )
268
-
269
- sim_output = gr.Label(
270
- label="Hardware Simulator Predictions", value=labels_value
271
- )
272
-
273
- # ------------ Event Listeners ------------ #
274
-
275
- sketchpad.input(
276
- fn=torch_predict,
277
- inputs=sketchpad,
278
- outputs=pytorch_output,
279
- )
280
-
281
- # TODO: implement simulator_predict
282
- predict_btn.click(
283
- fn=update_accelerator_config,
284
- inputs=[
285
- activation_dtype_component,
286
- weight_dtype_component,
287
- systolic_array_size_component,
288
- multiply_component,
289
- ],
290
- outputs=accelerator_config,
291
- ).then(
292
- fn=sim_predict_progress,
293
- inputs=[sketchpad, accelerator_config],
294
- outputs=sim_output,
295
- )
296
-
297
- # gr.on(
298
- # fn=update_accelerator_config,
299
- # inputs=[
300
- # activation_dtype_component,
301
- # weight_dtype_component,
302
- # systolic_array_size_component,
303
- # multiply_component,
304
- # ],
305
- # outputs=accelerator_config,
306
- # )
307
-
308
- # ------------
309
-
310
- if __name__ == "__main__":
311
- demo.queue()
312
- demo.launch(share=False)
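The deleted demo ran the MLP on the accelerator simulator layer by layer, folding each layer's bias into the weight matrix with bias_trick before tiling the GEMV. A minimal numpy sketch of that augmentation (illustrative only; the real helper lives in hardware_accelerators.simulation.matrix_utils):

import numpy as np

def bias_trick_sketch(W, b, x):
    """Fold the bias into the weights so that W_aug @ x_aug == W @ x + b."""
    W_aug = np.hstack([W, b.reshape(-1, 1)])  # (out, in + 1)
    x_aug = np.append(x, 1.0)                 # (in + 1,)
    return W_aug, x_aug

# Shapes match the first MLP layer used by the demo (784 -> 128).
W = np.random.randn(128, 784)
b = np.random.randn(128)
x = np.random.randn(784)
W_aug, x_aug = bias_trick_sketch(W, b, x)
assert np.allclose(W_aug @ x_aug, W @ x + b)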
demo/readme.md DELETED
@@ -1,9 +0,0 @@
- # Interactive Demo
-
- This directory contains a demo where you can test out and compare the performance of the hardware accelerator with a software implementation of the same model. The demo is built using [Gradio](https://gradio.app/), a Python library for creating interactive web applications.
-
- ## Running the Demo
-
- Install the project requirements, then from the repo root simply run `python demo/app.py` to start the demo.
-
- We will also provide a ready to go Docker image soon!
demo/requirements.txt DELETED
@@ -1,17 +0,0 @@
- jupyter==1.1.1
- ipykernel==6.29.5
- tqdm==4.67.0
- numpy==2.2.1
- ipython==8.12.3
- isort==5.13.2
- numpy==2.2.1
- pandas==2.2.3
- pyrtl==0.11.2
- matplotlib==3.10.0
- pytest==8.3.4
- torch==2.4.1
- torchvision==0.19.1
- onnx==1.17.0
- netron==8.1.3
- gradio==5.16.0
- black[jupyter]==24.10.0
hardware_accelerators/__init__.py CHANGED
@@ -1,4 +1,8 @@
- from .dtypes import BF16, Float8
+ from dotenv import load_dotenv
+
+ load_dotenv()
+
+ from .dtypes import BF16, Float8, Float16, Float32
  from .rtllib import (
      FloatAdderPipelined,
      FloatMultiplierPipelined,
@@ -8,10 +12,21 @@ from .rtllib import (
      lmul_fast,
      lmul_simple,
  )
+ from .simulation import (
+     get_sim_cache_dir,
+     set_sim_cache_dir,
+     CompiledAcceleratorSimulator,
+ )
+

  __all__ = [
+     "get_sim_cache_dir",
+     "set_sim_cache_dir",
+     "CompiledAcceleratorSimulator",
      "Float8",
      "BF16",
+     "Float16",
+     "Float32",
      "float_adder",
      "FloatAdderPipelined",
      "float_multiplier",
hardware_accelerators/analysis/__init__.py ADDED
File without changes
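With the updated top-level __init__.py above, the simulation cache helpers and the compiled simulator are importable from the package root. A small, hypothetical usage sketch (it assumes set_sim_cache_dir accepts a directory path; the CompiledAcceleratorSimulator constructor is not shown in this view, so it is only imported here):

import os

from hardware_accelerators import (
    BF16,
    Float8,
    CompiledAcceleratorSimulator,  # constructor arguments not shown in this diff
    get_sim_cache_dir,
    set_sim_cache_dir,
)

# Mirror the Dockerfile's HWA_CACHE_DIR when running outside the container (assumed API).
set_sim_cache_dir(os.environ.get("HWA_CACHE_DIR", ".hwa_cache"))
print("simulation cache directory:", get_sim_cache_dir())
print("dtypes available:", BF16.__name__, Float8.__name__)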
hardware_accelerators/analysis/config.py ADDED
@@ -0,0 +1,25 @@
+ from ..dtypes import *
+ from ..rtllib.lmul import lmul_fast, lmul_simple
+ from ..rtllib.multipliers import float_multiplier
+
+
+ NN_TEST_BATCH_SIZE = 64
+
+ NN_TEST_SYSTOLIC_ARRAY_SIZE = 8
+
+ NN_TEST_ACCUM_ADDR_WIDTH = 12
+
+ NN_TEST_MUL_FNS = [
+     float_multiplier,
+     lmul_simple,
+     # lmul_fast,
+ ]
+
+ NN_TEST_WA_DTYPES = [
+     # (Float8, Float8),
+     (Float8, BF16),
+     (Float8, Float32),
+     (BF16, BF16),
+     # (BF16, Float32),
+     # (Float32, Float32),
+ ]
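config.py enumerates the (weight, activation) dtype pairs and multiplier functions that the precision experiments sweep over. A minimal sketch of how these constants could be expanded into a run list (the loop below is illustrative; the actual experiment drivers are nn/precision_eval.py and nn/run_precision_comparison.py, whose diffs are not shown here):

from itertools import product

from hardware_accelerators.analysis.config import (
    NN_TEST_MUL_FNS,
    NN_TEST_SYSTOLIC_ARRAY_SIZE,
    NN_TEST_WA_DTYPES,
)

# One configuration per (weight dtype, activation dtype) x multiplier function.
for (weight_dtype, act_dtype), mul_fn in product(NN_TEST_WA_DTYPES, NN_TEST_MUL_FNS):
    print(
        f"array={NN_TEST_SYSTOLIC_ARRAY_SIZE}x{NN_TEST_SYSTOLIC_ARRAY_SIZE} "
        f"weights={weight_dtype.__name__} activations={act_dtype.__name__} "
        f"multiplier={mul_fn.__name__}"
    )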
hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/config.mk ADDED
@@ -0,0 +1,10 @@
+ export DESIGN_NAME = lmul_pipelined_fast
+ export PLATFORM = nangate45
+ export VERILOG_FILES = $(DESIGN_DIR)/src/lmul_pipelined_fast.v
+ export SDC_FILE = $(DESIGN_DIR)/constraint.sdc
+
+ # These values must be multiples of placement site
+ export DIE_AREA = 0 0 100 100
+ export CORE_AREA = 10 10 90 90
+
+ export CLOCK_PERIOD = 1.0
hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/constraint.sdc ADDED
@@ -0,0 +1 @@
+ create_clock -name clk -period 1.0 [get_ports {clk}]
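The SDC constrains the clk port to a 1.0 ns period, i.e. a 1 GHz target, matching CLOCK_PERIOD in config.mk. A quick, hypothetical sanity check against the PyRTL timing numbers that analyze() in generate.py below reports (assuming max_delay is in picoseconds, as the RTLAnalysis repr indicates):

# Compare the 1.0 ns SDC clock target against a PyRTL timing estimate.
CLOCK_PERIOD_NS = 1.0                       # from config.mk / constraint.sdc
target_freq_mhz = 1e3 / CLOCK_PERIOD_NS     # 1000 MHz

max_delay_ps = 850.0                        # placeholder value; use RTLAnalysis.max_delay
achievable_freq_mhz = 1e6 / max_delay_ps    # ps -> MHz

print(f"target {target_freq_mhz:.0f} MHz, achievable {achievable_freq_mhz:.0f} MHz")
print("meets target" if achievable_freq_mhz >= target_freq_mhz else "misses target")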
hardware_accelerators/analysis/generate.py ADDED
@@ -0,0 +1,958 @@
1
+ from pathlib import Path
3
+ from pandas import DataFrame
4
+ import pyrtl
5
+ from itertools import product
6
+ from pyrtl import *
7
+ from dataclasses import dataclass
8
+ from typing import Callable, Type, Literal, Optional
9
+
10
+ from .verilog_export import export_to_verilog
11
+ from ..dtypes import *
12
+ from ..rtllib import *
13
+ from ..rtllib.processing_element import ProcessingElement
14
+ from ..rtllib.adders import *
15
+ from ..rtllib.multipliers import *
16
+ from ..rtllib.lmul import *
17
+ from ..rtllib.utils.common import *
18
+ from ..simulation.utils import *
19
+
20
+
21
+ def create_inputs(**named_bitwidths):
22
+ """
23
+ Create PyRTL Input wires with specified bitwidths.
24
+
25
+ Args:
26
+ **named_bitwidths: Named bitwidths where the key is used as the wire name
27
+
28
+ Returns:
29
+ Generator of PyRTL Input wires
30
+
31
+ Note:
32
+ You must use all keyword arguments
33
+ """
34
+
35
+ # If using keyword arguments
36
+ for name, bitwidth in named_bitwidths.items():
37
+ yield pyrtl.Input(bitwidth, name=name) # type: ignore
38
+
39
+
40
+ def create_outputs(*args, **named_wires):
41
+ """
42
+ Create PyRTL Output wires connected to the input wires.
43
+
44
+ Args:
45
+ *args: Variable number of wires to connect to unnamed outputs
46
+ **named_wires: Named wires where the key is used as the output wire name
47
+
48
+ Note:
49
+ You must use either all positional arguments or all keyword arguments, not a mix.
50
+ """
51
+ if args and named_wires:
52
+ raise ValueError(
53
+ "Please use either all positional arguments or all keyword arguments, not a mix."
54
+ )
55
+
56
+ # If using positional arguments
57
+ for wire in args:
58
+ out = pyrtl.Output(len(wire), name=wire.name.replace("tmp", "out")) # type: ignore
59
+ out <<= wire
60
+
61
+ # If using keyword arguments
62
+ for name, wire in named_wires.items():
63
+ out = pyrtl.Output(len(wire), name=name) # type: ignore
64
+ out <<= wire
65
+
66
+
67
+ @dataclass
68
+ class RTLAnalysis:
69
+ """Results of RTL analysis."""
70
+
71
+ max_delay: float
72
+ max_freq: float
73
+ logic_area: float
74
+ mem_area: float
75
+ name: Optional[str] = None
76
+
77
+ def __repr__(self):
78
+ if self.name is None:
79
+ return (
80
+ f"RTLAnalysisResults("
81
+ f"max_delay={self.max_delay:.2f} ps, "
82
+ f"max_freq={self.max_freq:.2f} MHz, "
83
+ f"logic_area={self.logic_area:.2f}um², "
84
+ f"mem_area={self.mem_area:.2f}um²)"
85
+ )
86
+ else:
87
+ return (
88
+ f"RTLAnalysisResults for {self.name}:\n\t"
89
+ f"max_delay={self.max_delay:.2f} ps\n\t"
90
+ f"max_freq={self.max_freq:.2f} MHz\n\t"
91
+ f"logic_area={self.logic_area:.2f}um²\n\t"
92
+ f"mem_area={self.mem_area:.2f}um²"
93
+ )
94
+
95
+
96
+ def analyze(
97
+ block: Block | None = None, synth: bool = True, opt: bool = True, name=None
98
+ ):
99
+ if block is not None:
100
+ pyrtl.set_working_block(block)
101
+
102
+ if synth:
103
+ pyrtl.synthesize()
104
+ if opt:
105
+ pyrtl.optimize()
106
+
107
+ timing = pyrtl.TimingAnalysis()
108
+ max_delay = timing.max_length()
109
+ max_freq = timing.max_freq()
110
+ logic_area, mem_area = pyrtl.area_estimation()
111
+
112
+ return RTLAnalysis(
113
+ name=name,
114
+ max_delay=max_delay,
115
+ max_freq=max_freq,
116
+ logic_area=logic_area * 1e6,
117
+ mem_area=mem_area * 1e6,
118
+ )
119
+
120
+
121
+ def create_adder_blocks(dtype: Type[BaseFloat]) -> dict[str, Block]:
122
+ bits = dtype.bitwidth()
123
+ e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
124
+
125
+ combinational_block = pyrtl.Block()
126
+ combinational_fast_block = pyrtl.Block()
127
+ adder_pipelined_block = pyrtl.Block()
128
+ adder_pipelined_fast_block = pyrtl.Block()
129
+ stage_2_block = pyrtl.Block()
130
+ stage_2_fast_block = pyrtl.Block()
131
+ stage_3_block = pyrtl.Block()
132
+ stage_4_block = pyrtl.Block()
133
+ stage_4_fast_block = pyrtl.Block()
134
+ stage_5_block = pyrtl.Block()
135
+
136
+ # Combinational design
137
+ with set_working_block(combinational_block):
138
+ create_outputs(
139
+ *float_adder(
140
+ *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=False
141
+ )
142
+ )
143
+
144
+ with set_working_block(combinational_fast_block):
145
+ create_outputs(
146
+ *float_adder(
147
+ *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=True
148
+ )
149
+ )
150
+
151
+ # Complete pipelined design
152
+ with set_working_block(adder_pipelined_block):
153
+ create_outputs(
154
+ float_adder_pipelined(
155
+ *create_inputs(float_a=bits, float_b=bits),
156
+ dtype=dtype,
157
+ fast=False,
158
+ )
159
+ )
160
+
161
+ with set_working_block(adder_pipelined_fast_block):
162
+ create_outputs(
163
+ float_adder_pipelined(
164
+ *create_inputs(float_a=bits, float_b=bits),
165
+ dtype=dtype,
166
+ fast=True,
167
+ )
168
+ )
169
+
170
+ # Stages 1 & 2
171
+ with set_working_block(stage_2_block):
172
+ float_components = extract_float_components(
173
+ *create_inputs(float_a=bits, float_b=bits),
174
+ e_bits=e_bits,
175
+ m_bits=m_bits,
176
+ )
177
+ stage_2_outputs = adder_stage_2(
178
+ *float_components,
179
+ e_bits,
180
+ m_bits,
181
+ fast=False,
182
+ )
183
+ create_outputs(*stage_2_outputs)
184
+
185
+ with set_working_block(stage_2_fast_block):
186
+ float_components = extract_float_components(
187
+ *create_inputs(float_a=bits, float_b=bits),
188
+ e_bits=e_bits,
189
+ m_bits=m_bits,
190
+ )
191
+ stage_2_outputs = adder_stage_2(
192
+ *float_components,
193
+ e_bits,
194
+ m_bits,
195
+ fast=True,
196
+ )
197
+ create_outputs(*stage_2_outputs)
198
+
199
+ # Stage 3
200
+ with set_working_block(stage_3_block):
201
+ # Perform alignment and generate SGR bits
202
+ stage_3_outputs = adder_stage_3(
203
+ *create_inputs(mant_smaller=m_bits + 1, shift_amount=e_bits),
204
+ e_bits=e_bits,
205
+ m_bits=m_bits,
206
+ )
207
+ create_outputs(*stage_3_outputs)
208
+
209
+ # Stage 4
210
+ with set_working_block(stage_4_block):
211
+ # Perform mantissa addition and leading zero detection
212
+ stage_4_outputs = adder_stage_4(
213
+ *create_inputs(mant_aligned=m_bits + 1, mant_unchanged=m_bits + 1, s_xor=1),
214
+ m_bits=m_bits,
215
+ fast=False,
216
+ )
217
+ create_outputs(*stage_4_outputs)
218
+
219
+ with set_working_block(stage_4_fast_block):
220
+ # Perform mantissa addition and leading zero detection
221
+ stage_4_outputs = adder_stage_4(
222
+ *create_inputs(mant_aligned=m_bits + 1, mant_unchanged=m_bits + 1, s_xor=1),
223
+ m_bits=m_bits,
224
+ fast=True,
225
+ )
226
+ create_outputs(*stage_4_outputs)
227
+
228
+ # Stage 5
229
+ with set_working_block(stage_5_block):
230
+ # Perform normalization, rounding, and final assembly
231
+ stage_5_outputs = adder_stage_5(
232
+ *create_inputs(
233
+ abs_mantissa=m_bits + 2,
234
+ sticky_bit=1,
235
+ guard_bit=1,
236
+ round_bit=1,
237
+ lzc=4,
238
+ exp_larger=e_bits,
239
+ sign_a=1,
240
+ sign_b=1,
241
+ exp_diff=e_bits + 1,
242
+ is_neg=1,
243
+ ),
244
+ e_bits=e_bits,
245
+ m_bits=m_bits,
246
+ )
247
+ create_outputs(*stage_5_outputs)
248
+
249
+ # Return all the generated blocks for analysis
250
+ return {
251
+ "adder_combinational": combinational_block,
252
+ "adder_combinational_fast": combinational_fast_block,
253
+ "adder_pipelined": adder_pipelined_block,
254
+ "adder_pipelined_fast": adder_pipelined_fast_block,
255
+ "adder_stage_2": stage_2_block,
256
+ "adder_stage_2_fast": stage_2_fast_block,
257
+ "adder_stage_3": stage_3_block,
258
+ "adder_stage_4": stage_4_block,
259
+ "adder_stage_4_fast": stage_4_fast_block,
260
+ "adder_stage_5": stage_5_block,
261
+ }
262
+
263
+
264
+ def create_multiplier_blocks(dtype: Type[BaseFloat], fast: bool) -> dict[str, Block]:
265
+ bits = dtype.bitwidth()
266
+ e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
267
+
268
+ combinational_block = pyrtl.Block()
269
+ multiplier_block = pyrtl.Block()
270
+ stage_2_block = pyrtl.Block()
271
+ stage_3_block = pyrtl.Block()
272
+ stage_4_block = pyrtl.Block()
273
+
274
+ # Combinational design
275
+ with set_working_block(combinational_block):
276
+ create_outputs(
277
+ float_multiplier(
278
+ *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=fast
279
+ )
280
+ )
281
+
282
+ # Complete pipelined design
283
+ with set_working_block(multiplier_block):
284
+ multiplier = FloatMultiplierPipelined(
285
+ *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=fast
286
+ )
287
+ create_outputs(multiplier._result)
288
+
289
+ # Stage 1 & 2: Extract components and calculate sign, exponent sum, mantissa product
290
+ with set_working_block(stage_2_block):
291
+ float_components = extract_float_components(
292
+ *create_inputs(float_a=bits, float_b=bits),
293
+ e_bits=e_bits,
294
+ m_bits=m_bits,
295
+ )
296
+ stage_2_outputs = multiplier_stage_2(
297
+ *float_components,
298
+ m_bits,
299
+ fast,
300
+ )
301
+ create_outputs(*stage_2_outputs)
302
+
303
+ # Stage 3: Leading zero detection and exponent adjustment
304
+ with set_working_block(stage_3_block):
305
+ stage_3_outputs = multiplier_stage_3(
306
+ *create_inputs(exp_sum=e_bits + 1, mant_product=2 * m_bits + 2),
307
+ e_bits=e_bits,
308
+ m_bits=m_bits,
309
+ fast=fast,
310
+ )
311
+ create_outputs(*stage_3_outputs)
312
+
313
+ # Stage 4: Normalization, rounding, and final assembly
314
+ with set_working_block(stage_4_block):
315
+ stage_4_outputs = multiplier_stage_4(
316
+ *create_inputs(
317
+ unbiased_exp=e_bits,
318
+ leading_zeros=e_bits,
319
+ mantissa_product=2 * m_bits + 2,
320
+ ),
321
+ m_bits=m_bits,
322
+ e_bits=e_bits,
323
+ fast=fast,
324
+ )
325
+ create_outputs(*stage_4_outputs)
326
+
327
+ # Return all the generated blocks for analysis
328
+ faststr = "_fast" if fast else ""
329
+ return {
330
+ f"multiplier_combinational{faststr}": combinational_block,
331
+ f"multiplier{faststr}": multiplier_block,
332
+ f"multiplier_stage_2{faststr}": stage_2_block,
333
+ f"multiplier_stage_3{faststr}": stage_3_block,
334
+ f"multiplier_stage_4{faststr}": stage_4_block,
335
+ }
336
+
337
+
338
+ def create_lmul_blocks(dtype: Type[BaseFloat]) -> dict[str, Block]:
339
+ bits = dtype.bitwidth()
340
+
341
+ combinational_block = pyrtl.Block()
342
+ combinational_fast_block = pyrtl.Block()
343
+ pipelined_block = pyrtl.Block()
344
+ pipelined_fast_block = pyrtl.Block()
345
+
346
+ # Combinational design (simple)
347
+ with set_working_block(combinational_block):
348
+ create_outputs(
349
+ lmul_simple(*create_inputs(float_a=bits, float_b=bits), dtype=dtype)
350
+ )
351
+
352
+ # Combinational design (fast)
353
+ with set_working_block(combinational_fast_block):
354
+ create_outputs(
355
+ lmul_fast(*create_inputs(float_a=bits, float_b=bits), dtype=dtype)
356
+ )
357
+
358
+ # Pipelined design (simple)
359
+ with set_working_block(pipelined_block):
360
+ mult = LmulPipelined(
361
+ *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=False
362
+ )
363
+ create_outputs(mult.output_reg)
364
+
365
+ # Pipelined design (fast)
366
+ with set_working_block(pipelined_fast_block):
367
+ mult = LmulPipelined(
368
+ *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=True
369
+ )
370
+ create_outputs(mult.output_reg)
371
+
372
+ # Return all the generated blocks for analysis
373
+ return {
374
+ "lmul_combinational_simple": combinational_block,
375
+ "lmul_combinational_fast": combinational_fast_block,
376
+ "lmul_pipelined_simple": pipelined_block,
377
+ "lmul_pipelined_fast": pipelined_fast_block,
378
+ }
379
+
380
+
381
+ def connect_pe_io(pe: ProcessingElement):
382
+ # Connect the inputs and outputs of the processing element
383
+ w_bits, a_bits = pe.weight_type.bitwidth(), pe.data_type.bitwidth()
384
+ w_in, d_in, acc_in = create_inputs(
385
+ weight_in=w_bits, data_in=a_bits, accum_in=a_bits
386
+ )
387
+ pe.connect_weight(w_in)
388
+ pe.connect_data(d_in)
389
+ pe.connect_accum(acc_in)
390
+ pe.connect_control_signals(
391
+ *create_inputs(weight_en=1, data_en=1, mul_en=1, adder_en=1)
392
+ )
393
+ create_outputs(*pe.outputs.__dict__.values())
394
+
395
+
396
+ def create_pe_blocks(
397
+ dtypes: tuple[Type[BaseFloat], Type[BaseFloat]],
398
+ ) -> dict[str, Block]:
399
+ """Create a processing element for each pair of dtypes."""
400
+
401
+ weight_dtype, act_dtype = dtypes
402
+
403
+ # Defining blocks to encapsulate hardware
404
+
405
+ combinational_block = Block()
406
+ simple_pipeline_block = Block()
407
+ simple_pipeline_fast_block = Block()
408
+ full_pipeline_block = Block()
409
+ full_pipeline_fast_block = Block()
410
+
411
+ combinational_lmul_block = Block()
412
+ simple_pipeline_lmul_block = Block()
413
+ simple_pipeline_fast_lmul_block = Block()
414
+ full_pipeline_lmul_block = Block()
415
+ full_pipeline_fast_lmul_block = Block()
416
+
417
+ # Standard IEEE multiplier versions
418
+
419
+ with set_working_block(combinational_block):
420
+ pe = ProcessingElement(
421
+ data_type=act_dtype,
422
+ weight_type=weight_dtype,
423
+ accum_type=act_dtype,
424
+ multiplier=float_multiplier,
425
+ adder=float_adder,
426
+ pipeline_mult=False,
427
+ )
428
+ connect_pe_io(pe)
429
+
430
+ with set_working_block(simple_pipeline_block):
431
+ pe = ProcessingElement(
432
+ data_type=act_dtype,
433
+ weight_type=weight_dtype,
434
+ accum_type=act_dtype,
435
+ multiplier=float_multiplier,
436
+ adder=float_adder,
437
+ pipeline_mult=True,
438
+ )
439
+ connect_pe_io(pe)
440
+
441
+ with set_working_block(simple_pipeline_fast_block):
442
+ pe = ProcessingElement(
443
+ data_type=act_dtype,
444
+ weight_type=weight_dtype,
445
+ accum_type=act_dtype,
446
+ multiplier=float_multiplier_fast_unstable,
447
+ adder=float_adder_fast_unstable,
448
+ pipeline_mult=True,
449
+ )
450
+ connect_pe_io(pe)
451
+
452
+ with set_working_block(full_pipeline_block):
453
+ pe = ProcessingElement(
454
+ data_type=act_dtype,
455
+ weight_type=weight_dtype,
456
+ accum_type=act_dtype,
457
+ multiplier=float_multiplier_pipelined,
458
+ adder=float_adder_pipelined,
459
+ pipeline_mult=True,
460
+ )
461
+ connect_pe_io(pe)
462
+
463
+ with set_working_block(full_pipeline_fast_block):
464
+ pe = ProcessingElement(
465
+ data_type=act_dtype,
466
+ weight_type=weight_dtype,
467
+ accum_type=act_dtype,
468
+ multiplier=float_multiplier_pipelined_fast_unstable,
469
+ adder=float_adder_pipelined_fast_unstable,
470
+ pipeline_mult=True,
471
+ )
472
+ connect_pe_io(pe)
473
+
474
+ # L-mul versions
475
+
476
+ with set_working_block(combinational_lmul_block):
477
+ pe = ProcessingElement(
478
+ data_type=act_dtype,
479
+ weight_type=weight_dtype,
480
+ accum_type=act_dtype,
481
+ multiplier=lmul_simple,
482
+ adder=float_adder,
483
+ pipeline_mult=False,
484
+ )
485
+ connect_pe_io(pe)
486
+
487
+ with set_working_block(simple_pipeline_lmul_block):
488
+ pe = ProcessingElement(
489
+ data_type=act_dtype,
490
+ weight_type=weight_dtype,
491
+ accum_type=act_dtype,
492
+ multiplier=lmul_simple,
493
+ adder=float_adder,
494
+ pipeline_mult=True,
495
+ )
496
+ connect_pe_io(pe)
497
+
498
+ with set_working_block(simple_pipeline_fast_lmul_block):
499
+ pe = ProcessingElement(
500
+ data_type=act_dtype,
501
+ weight_type=weight_dtype,
502
+ accum_type=act_dtype,
503
+ multiplier=lmul_fast,
504
+ adder=float_adder_fast_unstable,
505
+ pipeline_mult=True,
506
+ )
507
+ connect_pe_io(pe)
508
+
509
+ with set_working_block(full_pipeline_lmul_block):
510
+ pe = ProcessingElement(
511
+ data_type=act_dtype,
512
+ weight_type=weight_dtype,
513
+ accum_type=act_dtype,
514
+ multiplier=lmul_pipelined,
515
+ adder=float_adder_pipelined,
516
+ pipeline_mult=True,
517
+ )
518
+ connect_pe_io(pe)
519
+
520
+ with set_working_block(full_pipeline_fast_lmul_block):
521
+ pe = ProcessingElement(
522
+ data_type=act_dtype,
523
+ weight_type=weight_dtype,
524
+ accum_type=act_dtype,
525
+ multiplier=lmul_pipelined_fast,
526
+ adder=float_adder_pipelined_fast_unstable,
527
+ pipeline_mult=True,
528
+ )
529
+ connect_pe_io(pe)
530
+
531
+ return {
532
+ "pe_combinational": combinational_block,
533
+ "pe_standard": simple_pipeline_block,
534
+ "pe_fast": simple_pipeline_fast_block,
535
+ "pe_pipelined": full_pipeline_block,
536
+ "pe_fast_pipelined": full_pipeline_fast_block,
537
+ "pe_combinational_lmul": combinational_lmul_block,
538
+ "pe_standard_lmul": simple_pipeline_lmul_block,
539
+ "pe_fast_lmul": simple_pipeline_fast_lmul_block,
540
+ "pe_pipelined_lmul": full_pipeline_lmul_block,
541
+ "pe_fast_pipelined_lmul": full_pipeline_fast_lmul_block,
542
+ }
543
+
544
+
545
+ def create_accelerator_blocks(
546
+ dtypes: tuple[Type[BaseFloat], Type[BaseFloat]],
547
+ array_size: int = 4,
548
+ addr_bits: int = 12,
549
+ ) -> dict[str, Block]:
550
+ """
551
+ Create accelerator blocks for all valid configurations based on the given inputs.
552
+
553
+ Args:
554
+ dtypes: Tuple of (weight_type, activation_type) data types
555
+ array_size: Size of the systolic array (N x N)
556
+ addr_bits: Bit width for accumulator address (uses default if None)
557
+
558
+ Returns:
559
+ Dictionary mapping configuration names to PyRTL blocks
560
+ """
561
+ weight_type, activation_type = dtypes
562
+
563
+ # Define all valid configurations to test
564
+ pipeline_options = [None, "low", "high"]
565
+ lmul_options = [False, True]
566
+ fast_options = [False, True]
567
+
568
+ # Create configs and blocks
569
+ blocks = {}
570
+ for pipeline, lmul, fast in product(pipeline_options, lmul_options, fast_options):
571
+ if pipeline is None and fast is True:
572
+ continue
573
+
574
+ # Create the configuration
575
+ config = AcceleratorAnalysisConfig(
576
+ array_size=array_size,
577
+ activation_type=activation_type,
578
+ weight_type=weight_type,
579
+ lmul=lmul,
580
+ accum_addr_width=addr_bits,
581
+ pipeline_level=pipeline,
582
+ use_fast_internals=fast,
583
+ )
584
+
585
+ block = pyrtl.Block()
586
+ with set_working_block(block):
587
+ AcceleratorTopLevel(config)
588
+
589
+ blocks[config.name] = block
590
+
591
+ return blocks
592
+
593
+
594
+ ################################################################
595
+
596
+ # if __name__ == "__main__":
597
+
598
+ # OUTPUT_DIR = Path("verilog")
599
+ # POSTSYNTH_DIR = OUTPUT_DIR / "pyrtl_synth"
600
+
601
+ # EXPORT_PRE_SYNTH = False
602
+ # EXPORT_POST_SYNTH = True
603
+ # RUN_ANALYSIS = True
604
+ # ANALYSIS_RESULT_DIR = Path("results")
605
+
606
+ # array_size = 8
607
+ # addr_bits = 12
608
+
609
+ # dtype_list = [Float8, BF16, Float32]
610
+
611
+ # dtype_names = {Float8: "fp8", BF16: "bf16", Float32: "fp32"}
612
+
613
+ # weight_act_dtypes = [
614
+ # (Float8, Float8),
615
+ # (Float8, BF16),
616
+ # (Float8, Float32),
617
+ # (BF16, BF16),
618
+ # (BF16, Float32),
619
+ # (Float32, Float32),
620
+ # ]
621
+
622
+ # # Hardware building blocks
623
+ # basic_component_analysis = []
624
+
625
+ # for dtype in dtype_list:
626
+ # block_dicts = [
627
+ # ("adder", create_adder_blocks(dtype)),
628
+ # ("multiplier", create_multiplier_blocks(dtype, fast=False)),
629
+ # ("multiplier", create_multiplier_blocks(dtype, fast=True)),
630
+ # ("lmul", create_lmul_blocks(dtype)),
631
+ # ]
632
+ # for component_name, block_dict in block_dicts:
633
+ # for name, block in block_dict.items():
634
+ # output_path = Path(component_name, dtype_names[dtype], f"{name}.v")
635
+ # if EXPORT_PRE_SYNTH:
636
+ # export_to_verilog(block, OUTPUT_DIR / output_path)
637
+ # if RUN_ANALYSIS:
638
+ # analysis_result = analyze(block, name=name)
639
+ # analysis_result.dtype = dtype_names[dtype]
640
+ # analysis_result.component = component_name
641
+ # basic_component_analysis.append(analysis_result.__dict__)
642
+ # if EXPORT_POST_SYNTH:
643
+ # export_to_verilog(block, POSTSYNTH_DIR / output_path)
644
+
645
+ # # More complex hardware
646
+ # pe_analysis = []
647
+ # accelerator_analysis = []
648
+
649
+ # for weight_dtype, act_dtype in weight_act_dtypes:
650
+ # folder_name = f"w{weight_dtype.bitwidth()}a{act_dtype.bitwidth()}"
651
+
652
+ # pe_blocks = create_pe_blocks((weight_dtype, act_dtype))
653
+ # for name, block in pe_blocks.items():
654
+ # pe_output_path = Path("pe", folder_name, f"{name}.v")
655
+ # if EXPORT_PRE_SYNTH:
656
+ # export_to_verilog(block, OUTPUT_DIR / pe_output_path)
657
+ # if RUN_ANALYSIS:
658
+ # analysis_result = analyze(block, name=name)
659
+ # analysis_result.weights = dtype_names[weight_dtype]
660
+ # analysis_result.activations = dtype_names[act_dtype]
661
+ # analysis_result.component = "pe"
662
+ # pe_analysis.append(analysis_result.__dict__)
663
+ # if EXPORT_POST_SYNTH:
664
+ # export_to_verilog(block, POSTSYNTH_DIR / pe_output_path)
665
+
666
+ # accelerator_blocks = create_accelerator_blocks(
667
+ # (weight_dtype, act_dtype), array_size, addr_bits
668
+ # )
669
+ # for name, block in accelerator_blocks.items():
670
+ # accelerator_output_path = Path("accelerator", folder_name, f"{name}.v")
671
+ # if EXPORT_PRE_SYNTH:
672
+ # export_to_verilog(block, OUTPUT_DIR / accelerator_output_path)
673
+ # if RUN_ANALYSIS:
674
+ # analysis_result = analyze(block, name=name)
675
+ # analysis_result.weights = dtype_names[weight_dtype]
676
+ # analysis_result.activations = dtype_names[act_dtype]
677
+ # analysis_result.component = "accelerator"
678
+ # accelerator_analysis.append(analysis_result.__dict__)
679
+ # if EXPORT_POST_SYNTH:
680
+ # export_to_verilog(block, POSTSYNTH_DIR / accelerator_output_path)
681
+
682
+ # if RUN_ANALYSIS:
683
+ # DataFrame(basic_component_analysis).to_csv(
684
+ # ANALYSIS_RESULT_DIR / "component_analysis.csv", index=False
685
+ # )
686
+ # DataFrame(pe_analysis).to_csv(
687
+ # ANALYSIS_RESULT_DIR / "pe_analysis.csv", index=False
688
+ # )
689
+ # DataFrame(accelerator_analysis).to_csv(
690
+ # ANALYSIS_RESULT_DIR / "accelerator_analysis.csv", index=False
691
+ # )
692
+
693
+
694
+ import multiprocessing as mp
695
+ from pathlib import Path
696
+ import os
697
+ import csv
698
+ import time
699
+ from functools import partial
700
+ from pandas import DataFrame
701
+ import json
702
+ import traceback
703
+
704
+
705
+ def process_block(
706
+ block,
707
+ name,
708
+ output_dir,
709
+ postsynth_dir,
710
+ export_pre_synth,
711
+ export_post_synth,
712
+ run_analysis,
713
+ analysis_result_dir,
714
+ component_name=None,
715
+ dtype=None,
716
+ weight_dtype=None,
717
+ act_dtype=None,
718
+ dtype_names=None,
719
+ output_path=None,
720
+ ):
721
+ """Process a single block with optional export and analysis"""
722
+ result = None
723
+ try:
724
+ if export_pre_synth and output_path:
725
+ os.makedirs((output_dir / output_path).parent, exist_ok=True)
726
+ export_to_verilog(block, output_dir / output_path)
727
+
728
+ if run_analysis:
729
+ analysis_result = analyze(block, name=name)
730
+
731
+ # Set appropriate attributes based on the component type
732
+ if component_name and dtype:
733
+ analysis_result.dtype = dtype_names[dtype]
734
+ analysis_result.component = component_name
735
+ result_type = "component"
736
+ elif weight_dtype and act_dtype:
737
+ analysis_result.weights = dtype_names[weight_dtype]
738
+ analysis_result.activations = dtype_names[act_dtype]
739
+ analysis_result.component = component_name
740
+ result_type = "pe" if component_name == "pe" else "accelerator"
741
+
742
+ result = (result_type, analysis_result.__dict__)
743
+
744
+ if export_post_synth and output_path:
745
+ os.makedirs((postsynth_dir / output_path).parent, exist_ok=True)
746
+ export_to_verilog(block, postsynth_dir / output_path)
747
+
748
+ return result
749
+ except Exception as e:
750
+ error_msg = f"Error processing {name}: {str(e)}\n{traceback.format_exc()}"
751
+ print(error_msg)
752
+ with open(analysis_result_dir / "errors.log", "a") as f:
753
+ f.write(f"{error_msg}\n{'='*80}\n")
754
+ return None
755
+
756
+
757
+ def save_result(result, analysis_result_dir, result_files):
758
+ """Save a single result to the appropriate CSV file"""
759
+ if not result:
760
+ return
761
+
762
+ result_type, data = result
763
+
764
+ # Get the appropriate file path and headers
765
+ if result_type == "component":
766
+ file_path = analysis_result_dir / "component_analysis.csv"
767
+ elif result_type == "pe":
768
+ file_path = analysis_result_dir / "pe_analysis.csv"
769
+ elif result_type == "accelerator":
770
+ file_path = analysis_result_dir / "accelerator_analysis.csv"
771
+
772
+ # Create directory if it doesn't exist
773
+ os.makedirs(file_path.parent, exist_ok=True)
774
+
775
+ # Check if file exists to determine if we need to write headers
776
+ file_exists = file_path.exists()
777
+
778
+ # Get headers from the data
779
+ headers = list(data.keys())
780
+
781
+ # Open file in append mode
782
+ with open(file_path, "a", newline="") as f:
783
+ writer = csv.DictWriter(f, fieldnames=headers)
784
+
785
+ # Write headers if file doesn't exist
786
+ if not file_exists:
787
+ writer.writeheader()
788
+
789
+ # Write the data row
790
+ writer.writerow(data)
791
+
792
+ # Track that we've written to this file
793
+ result_files.add(file_path)
794
+
795
+
796
+ def result_callback(result, analysis_result_dir, result_files):
797
+ """Callback function for when a process completes"""
798
+ if result:
799
+ save_result(result, analysis_result_dir, result_files)
800
+
801
+
802
+ if __name__ == "__main__":
803
+ # Create a multiprocessing pool with as many processes as CPU cores
804
+ pool = mp.Pool(processes=mp.cpu_count())
805
+
806
+ # Set to track which result files we've written to
807
+ result_files = set()
808
+
809
+ OUTPUT_DIR = Path("verilog")
810
+ POSTSYNTH_DIR = OUTPUT_DIR / "pyrtl_synth"
811
+
812
+ EXPORT_PRE_SYNTH = False
813
+ EXPORT_POST_SYNTH = True
814
+ RUN_ANALYSIS = True
815
+ ANALYSIS_RESULT_DIR = Path("results")
816
+
817
+ # Create output directories
818
+ os.makedirs(OUTPUT_DIR, exist_ok=True)
819
+ os.makedirs(POSTSYNTH_DIR, exist_ok=True)
820
+ os.makedirs(ANALYSIS_RESULT_DIR, exist_ok=True)
821
+
822
+ array_size = 8
823
+ addr_bits = 12
824
+
825
+ dtype_list = [Float8, BF16, Float32]
826
+
827
+ dtype_names = {Float8: "fp8", BF16: "bf16", Float32: "fp32"}
828
+
829
+ weight_act_dtypes = [
830
+ (Float8, Float8),
831
+ (Float8, BF16),
832
+ (Float8, Float32),
833
+ (BF16, BF16),
834
+ (BF16, Float32),
835
+ (Float32, Float32),
836
+ ]
837
+
838
+ # Create a partial function with common arguments
839
+ process_block_partial = partial(
840
+ process_block,
841
+ output_dir=OUTPUT_DIR,
842
+ postsynth_dir=POSTSYNTH_DIR,
843
+ export_pre_synth=EXPORT_PRE_SYNTH,
844
+ export_post_synth=EXPORT_POST_SYNTH,
845
+ run_analysis=RUN_ANALYSIS,
846
+ analysis_result_dir=ANALYSIS_RESULT_DIR,
847
+ dtype_names=dtype_names,
848
+ )
849
+
850
+ # Create a callback function with common arguments
851
+ callback = partial(
852
+ result_callback,
853
+ analysis_result_dir=ANALYSIS_RESULT_DIR,
854
+ result_files=result_files,
855
+ )
856
+
857
+ # Track all submitted tasks
858
+ tasks = []
859
+
860
+ # Hardware building blocks
861
+ for dtype in dtype_list:
862
+ block_dicts = [
863
+ ("adder", create_adder_blocks(dtype)),
864
+ ("multiplier", create_multiplier_blocks(dtype, fast=False)),
865
+ ("multiplier", create_multiplier_blocks(dtype, fast=True)),
866
+ ("lmul", create_lmul_blocks(dtype)),
867
+ ]
868
+
869
+ for component_name, block_dict in block_dicts:
870
+ for name, block in block_dict.items():
871
+ output_path = Path(component_name, dtype_names[dtype], f"{name}.v")
872
+
873
+ # Submit task to process pool
874
+ task = pool.apply_async(
875
+ process_block_partial,
876
+ args=(block, name),
877
+ kwds={
878
+ "component_name": component_name,
879
+ "dtype": dtype,
880
+ "output_path": output_path,
881
+ },
882
+ callback=callback,
883
+ )
884
+ tasks.append(task)
885
+
886
+ # More complex hardware
887
+ for weight_dtype, act_dtype in weight_act_dtypes:
888
+ folder_name = f"w{weight_dtype.bitwidth()}a{act_dtype.bitwidth()}"
889
+
890
+ # Process PE blocks
891
+ pe_blocks = create_pe_blocks((weight_dtype, act_dtype))
892
+ for name, block in pe_blocks.items():
893
+ pe_output_path = Path("pe", folder_name, f"{name}.v")
894
+
895
+ task = pool.apply_async(
896
+ process_block_partial,
897
+ args=(block, name),
898
+ kwds={
899
+ "component_name": "pe",
900
+ "weight_dtype": weight_dtype,
901
+ "act_dtype": act_dtype,
902
+ "output_path": pe_output_path,
903
+ },
904
+ callback=callback,
905
+ )
906
+ tasks.append(task)
907
+
908
+ # Process accelerator blocks
909
+ accelerator_blocks = create_accelerator_blocks(
910
+ (weight_dtype, act_dtype), array_size, addr_bits
911
+ )
912
+ for name, block in accelerator_blocks.items():
913
+ accelerator_output_path = Path("accelerator", folder_name, f"{name}.v")
914
+
915
+ task = pool.apply_async(
916
+ process_block_partial,
917
+ args=(block, name),
918
+ kwds={
919
+ "component_name": "accelerator",
920
+ "weight_dtype": weight_dtype,
921
+ "act_dtype": act_dtype,
922
+ "output_path": accelerator_output_path,
923
+ },
924
+ callback=callback,
925
+ )
926
+ tasks.append(task)
927
+
928
+ # Wait for all tasks to complete
929
+ try:
930
+ # Monitor progress
931
+ total_tasks = len(tasks)
932
+ completed = 0
933
+ print(f"Processing {total_tasks} tasks using {mp.cpu_count()} processes")
934
+
935
+ while completed < total_tasks:
936
+ new_completed = sum(1 for task in tasks if task.ready())
937
+ if new_completed > completed:
938
+ completed = new_completed
939
+ print(
940
+ f"Progress: {completed}/{total_tasks} tasks completed ({completed/total_tasks*100:.1f}%)"
941
+ )
942
+ time.sleep(1)
943
+
944
+ # Make sure all tasks are properly completed
945
+ for task in tasks:
946
+ task.get() # This will raise any exceptions that occurred in the task
947
+
948
+ except KeyboardInterrupt:
949
+ print("Process interrupted by user. Partial results have been saved.")
950
+ except Exception as e:
951
+ print(f"An error occurred: {str(e)}")
952
+ print("Partial results have been saved.")
953
+ finally:
954
+ # Close the pool
955
+ pool.close()
956
+ pool.join()
957
+
958
+ print(f"Results saved to: {', '.join(str(f) for f in result_files)}")
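generate.py's analyze() synthesizes and optimizes a PyRTL block, then reports critical-path delay, maximum frequency, and estimated logic/memory area. A minimal, hypothetical standalone use on a trivial circuit (an 8-bit integer adder rather than one of the floating-point blocks above), runnable inside this repository:

import pyrtl

from hardware_accelerators.analysis.generate import analyze

pyrtl.reset_working_block()
a = pyrtl.Input(8, "a")
b = pyrtl.Input(8, "b")
total = pyrtl.Output(9, "total")
total <<= a + b  # simple combinational adder just to exercise the analysis path

# Synthesize + optimize, then print timing and area estimates via RTLAnalysis.
result = analyze(pyrtl.working_block(), synth=True, opt=True, name="adder8")
print(result)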
hardware_accelerators/analysis/hardware_stats.py ADDED
@@ -0,0 +1,458 @@
1
+ import pandas as pd
2
+ import numpy as np
3
+ from typing import Literal, Dict, Tuple, List, Any, Optional
4
+
5
+ # Mapping from UI options to dataframe values
6
+ DTYPE_MAP = {"float8": "fp8", "bfloat16": "bf16", "float32": "fp32"}
7
+
8
+ SPEED_MAP = {"Fast": True, "Efficient": False}
9
+
10
+ PIPELINE_MAP = {
11
+ "None": "combinational",
12
+ "Low": "combinational",
13
+ "Full": "pipelined",
14
+ "High": "pipelined",
15
+ }
16
+
17
+
18
+ def filter_components(
19
+ df: pd.DataFrame, operation: str, dtype: str, is_fast: bool, architecture: str
20
+ ) -> pd.DataFrame:
21
+ """
22
+ Filter the dataframe to get components matching the specified criteria.
23
+
24
+ Args:
25
+ df: DataFrame containing component data
26
+ operation: Type of operation ('multiplier', 'lmul', 'adder')
27
+ dtype: Data type ('fp8', 'bf16', 'fp32')
28
+ is_fast: Whether to use fast components
29
+ architecture: Architecture type ('combinational', 'pipelined')
30
+
31
+ Returns:
32
+ Filtered DataFrame
33
+ """
34
+ filtered_df = df[
35
+ (df["operation"] == operation)
36
+ & (df["dtype"] == dtype)
37
+ & (df["is_fast"] == is_fast)
38
+ & (df["architecture"] == architecture)
39
+ ]
40
+
41
+ if filtered_df.empty:
42
+ # If no exact match, try without the is_fast constraint
43
+ filtered_df = df[
44
+ (df["operation"] == operation)
45
+ & (df["dtype"] == dtype)
46
+ & (df["architecture"] == architecture)
47
+ ]
48
+
49
+ if filtered_df.empty:
50
+ # If still no match, try without architecture constraint
51
+ filtered_df = df[(df["operation"] == operation) & (df["dtype"] == dtype)]
52
+
53
+ return filtered_df
54
+
55
+
56
+ def get_pe_components(
57
+ df: pd.DataFrame, mult_type: str, dtype: str, is_fast: bool, architecture: str
58
+ ) -> Tuple[pd.Series, pd.Series]:
59
+ """
60
+ Get the multiplier/lmul and adder components for a PE.
61
+
62
+ Args:
63
+ df: DataFrame containing component data
64
+ mult_type: Type of multiplier ('multiplier' or 'lmul')
65
+ dtype: Data type ('fp8', 'bf16', 'fp32')
66
+ is_fast: Whether to use fast components
67
+ architecture: Architecture type ('combinational', 'pipelined')
68
+
69
+ Returns:
70
+ Tuple of (multiplier, adder) Series
71
+ """
72
+ mult_df = filter_components(df, mult_type, dtype, is_fast, architecture)
73
+ adder_df = filter_components(df, "adder", dtype, is_fast, architecture)
74
+
75
+ if mult_df.empty or adder_df.empty:
76
+ raise ValueError(
77
+ f"Could not find components for {mult_type}, {dtype}, {is_fast}, {architecture}"
78
+ )
79
+
80
+ # Get the first matching component
81
+ multiplier = mult_df.iloc[0]
82
+ adder = adder_df.iloc[0]
83
+
84
+ return multiplier, adder
85
+
86
+
87
+ def calculate_pe_metrics(multiplier: pd.Series, adder: pd.Series) -> Dict[str, float]:
88
+ """
89
+ Calculate metrics for a single PE (Processing Element).
90
+
91
+ Args:
92
+ multiplier: Series containing multiplier component data
93
+ adder: Series containing adder component data
94
+
95
+ Returns:
96
+ Dictionary of PE metrics
97
+ """
98
+ # Area and power are additive
99
+ area = multiplier["area"] + adder["area"]
100
+ power = multiplier["power"] + adder["power"]
101
+
102
+ # Critical path depends on architecture
103
+ if (
104
+ multiplier["architecture"] == "combinational"
105
+ and adder["architecture"] == "combinational"
106
+ ):
107
+ delay = max(multiplier["max_arrival_time"], adder["max_arrival_time"])
108
+ pipeline_stages = 1
109
+ else:
110
+ # For pipelined designs, we assume the critical path is the slowest stage
111
+ delay = max(multiplier["max_arrival_time"], adder["max_arrival_time"])
112
+ # Estimate pipeline stages
113
+ if multiplier["architecture"] == "pipelined":
114
+ mult_stages = 4 # Assumption for pipelined multiplier
115
+ else:
116
+ mult_stages = 1
117
+
118
+ if adder["architecture"] == "pipelined":
119
+ add_stages = 5 # Assumption for pipelined adder
120
+ else:
121
+ add_stages = 1
122
+
123
+ pipeline_stages = mult_stages + add_stages - 1 # -1 because they share a stage
124
+
125
+ # Calculate performance metrics
126
+ clock_freq_ghz = 1.0 / delay # GHz, assuming delay is in ns
127
+ ops_per_cycle = 2 # 1 multiply + 1 add = 2 FLOPs per cycle
128
+ tflops = clock_freq_ghz * ops_per_cycle / 1000 # TFLOPS for a single PE
129
+
130
+ # Efficiency metrics
131
+ tflops_per_watt = tflops / power
132
+ tflops_per_mm2 = tflops / (area * 1e-6) # Convert area to mm²
133
+ power_density = power / (area * 1e-6) # W/mm²
134
+
135
+ return {
136
+ "area_um2": area,
137
+ "area_mm2": area * 1e-6,
138
+ "power_w": power,
139
+ "delay_ns": delay,
140
+ "clock_freq_ghz": clock_freq_ghz,
141
+ "pipeline_stages": pipeline_stages,
142
+ "tflops": tflops,
143
+ "tflops_per_watt": tflops_per_watt,
144
+ "tflops_per_mm2": tflops_per_mm2,
145
+ "power_density": power_density,
146
+ "energy_per_op_pj": (power / (clock_freq_ghz * ops_per_cycle * 1e9))
147
+ * 1e12, # pJ per operation
148
+ }
149
+
150
+
151
+ def calculate_array_metrics(
152
+ pe_metrics: Dict[str, float], array_size: int, num_cores: int
153
+ ) -> Dict[str, float]:
154
+ """
155
+ Calculate metrics for a systolic array and the entire accelerator.
156
+
157
+ Args:
158
+ pe_metrics: Dictionary of PE metrics
159
+ array_size: Size of the systolic array (NxN)
160
+ num_cores: Number of accelerator cores
161
+
162
+ Returns:
163
+ Dictionary of array and accelerator metrics
164
+ """
165
+ # Number of PEs per array and total
166
+ pes_per_array = array_size * array_size
167
+ total_pes = pes_per_array * num_cores
168
+
169
+ # Scale metrics for a single array
170
+ array_area_mm2 = pe_metrics["area_mm2"] * pes_per_array
171
+ array_power_w = pe_metrics["power_w"] * pes_per_array
172
+
173
+ # Scale metrics for the entire accelerator
174
+ total_area_mm2 = array_area_mm2 * num_cores
175
+ total_power_w = array_power_w * num_cores
176
+
177
+ # Performance scales with the number of PEs
178
+ array_tflops = pe_metrics["tflops"] * pes_per_array
179
+ total_tflops = array_tflops * num_cores
180
+
181
+ # Efficiency metrics
182
+ total_tflops_per_watt = total_tflops / total_power_w
183
+ total_tflops_per_mm2 = total_tflops / total_area_mm2
184
+
185
+ # Latency calculation
186
+ # For an NxN array, data takes 2N-1 cycles to flow through
187
+ # Plus pipeline_stages-1 cycles for the pipeline to fill
188
+ pipeline_latency_cycles = pe_metrics["pipeline_stages"] - 1
189
+ array_latency_cycles = 2 * array_size - 1
190
+ total_latency_cycles = array_latency_cycles + pipeline_latency_cycles
191
+
192
+ # Time per cycle based on clock frequency
193
+ cycle_time_ns = 1.0 / pe_metrics["clock_freq_ghz"]
194
+
195
+ # Total latency in ns
196
+ total_latency_ns = total_latency_cycles * cycle_time_ns
197
+
198
+ # Throughput after pipeline is filled (ops per second)
199
+ throughput_ops_per_second = (
200
+ pe_metrics["clock_freq_ghz"] * 1e9 * pes_per_array * 2
201
+ ) # 2 ops per PE per cycle
202
+ total_throughput_ops_per_second = throughput_ops_per_second * num_cores
203
+
204
+ # Energy per matrix multiplication
205
+ # Assuming an NxN matrix multiply requires N³ operations
206
+ ops_per_matmul = array_size**3
207
+ energy_per_matmul_nj = (
208
+ (array_power_w / throughput_ops_per_second) * ops_per_matmul * 1e9
209
+ ) # nJ
210
+
211
+ # Inference metrics (assuming a simple MLP with 3 layers)
212
+ # Each layer requires a matrix multiplication
213
+ num_layers = 3
214
+ inference_ops = ops_per_matmul * num_layers
215
+ inference_latency_ns = (inference_ops / throughput_ops_per_second) * 1e9
216
+ inference_energy_uj = (
217
+ (total_power_w / total_throughput_ops_per_second) * inference_ops * 1e6
218
+ ) # uJ
219
+
220
+ return {
221
+ "array_size": array_size,
222
+ "num_cores": num_cores,
223
+ "pes_per_array": pes_per_array,
224
+ "total_pes": total_pes,
225
+ "clock_freq_ghz": pe_metrics["clock_freq_ghz"],
226
+ "array_area_mm2": array_area_mm2,
227
+ "total_area_mm2": total_area_mm2,
228
+ "array_power_w": array_power_w,
229
+ "total_power_w": total_power_w,
230
+ "array_tflops": array_tflops,
231
+ "total_tflops": total_tflops,
232
+ "tflops_per_watt": total_tflops_per_watt,
233
+ "tflops_per_mm2": total_tflops_per_mm2,
234
+ "power_density_w_mm2": total_power_w / total_area_mm2,
235
+ "total_latency_cycles": total_latency_cycles,
236
+ "total_latency_ns": total_latency_ns,
237
+ "throughput_gops": total_throughput_ops_per_second / 1e9, # GOPS
238
+ "energy_per_matmul_nj": energy_per_matmul_nj,
239
+ "inference_latency_ns": inference_latency_ns,
240
+ "inference_latency_us": inference_latency_ns / 1e3, # us
241
+ "inference_energy_uj": inference_energy_uj,
242
+ "inferences_per_second": 1e9 / inference_latency_ns,
243
+ "inferences_per_watt": (1e9 / inference_latency_ns) / total_power_w,
244
+ }
245
+
246
+
247
+ def format_metrics_for_display(metrics: Dict[str, float]) -> Dict[str, str]:
248
+ """
249
+ Format metrics for display in the Gradio UI.
250
+
251
+ Args:
252
+ metrics: Dictionary of metrics
253
+
254
+ Returns:
255
+ Dictionary of formatted metrics
256
+ """
257
+ formatted = {}
258
+
259
+ # Format area
260
+ formatted["Total Chip Area"] = f"{metrics['total_area_mm2']:.2f} mm²"
261
+
262
+ # Format performance
263
+ formatted["Clock Speed"] = f"{metrics['clock_freq_ghz']:.2f} GHz"
264
+ formatted["Total Performance"] = f"{metrics['total_tflops']:.2f} TFLOPS"
265
+ formatted["Performance per Core"] = f"{metrics['array_tflops']:.2f} TFLOPS"
266
+ formatted["Performance per Watt"] = f"{metrics['tflops_per_watt']:.2f} TFLOPS/W"
267
+ formatted["Performance per Area"] = f"{metrics['tflops_per_mm2']:.2f} TFLOPS/mm²"
268
+
269
+ # Format power
270
+ formatted["Total Power"] = f"{metrics['total_power_w']:.2f} W"
271
+ formatted["Power Density"] = f"{metrics['power_density_w_mm2']:.2f} W/mm²"
272
+
273
+ # Format latency and throughput
274
+ formatted["Matrix Mult Latency"] = f"{metrics['total_latency_ns']:.2f} ns"
275
+ formatted["Inference Latency"] = f"{metrics['inference_latency_us']:.2f} µs"
276
+ formatted["Throughput"] = f"{metrics['throughput_gops']:.2f} GOPS"
277
+
278
+ # Format energy
279
+ formatted["Energy per Matrix Mult"] = f"{metrics['energy_per_matmul_nj']:.2f} nJ"
280
+ # formatted["Inference Energy"] = f"{metrics['inference_energy_uj']:.2f} µJ"
281
+ # formatted["Inferences per Second"] = f"{metrics['inferences_per_second']:.0f}"
282
+ # formatted["Inferences per Watt"] = f"{metrics['inferences_per_watt']:.0f}"
283
+
284
+ return formatted
285
+
286
+
287
+ def calculate_hardware_stats(
288
+ df: pd.DataFrame,
289
+ activation_type: Literal["float8", "bfloat16", "float32"],
290
+ weight_type: Literal["float8", "bfloat16", "float32"],
291
+ systolic_array_size: int,
292
+ num_accelerator_cores: int,
293
+ fast_internals: Literal["Fast", "Efficient"],
294
+ pipeline_level: Literal["None", "Low", "Full"],
295
+ process_node_size: Optional[Literal["7nm", "45nm", "130nm"]] = None,
296
+ ) -> Tuple[Dict[str, str], Dict[str, str]]:
297
+ """
298
+ Calculate hardware statistics for both lmul and standard IEEE multiplier configurations.
299
+
300
+ Args:
301
+ df: DataFrame containing component data
302
+ activation_type: Type of activations
303
+ weight_type: Type of weights
304
+ systolic_array_size: Size of the systolic array
305
+ num_accelerator_cores: Number of accelerator cores
306
+ fast_internals: Whether to use fast or efficient components
307
+ pipeline_level: Level of pipelining
308
+ process_node_size: Process node size (ignored for now)
309
+
310
+ Returns:
311
+ Tuple of (lmul_metrics, ieee_metrics) dictionaries
312
+ """
313
+ # Map UI options to dataframe values
314
+ act_dtype = DTYPE_MAP[activation_type]
315
+ weight_dtype = DTYPE_MAP[weight_type]
316
+ is_fast = SPEED_MAP[fast_internals]
317
+ architecture = PIPELINE_MAP[pipeline_level]
318
+
319
+ # For mixed precision, use the larger precision for the PE. Rank the dtypes
+ # explicitly; comparing the mapped strings ('fp8', 'bf16', 'fp32') with >=
+ # would order them lexicographically, not by precision.
+ precision_rank = {"fp8": 0, "bf16": 1, "fp32": 2}
+ pe_dtype = (
+ act_dtype
+ if precision_rank[act_dtype] >= precision_rank[weight_dtype]
+ else weight_dtype
+ )
325
+
326
+ # Calculate metrics for lmul configuration
327
+ try:
328
+ lmul_mult, lmul_adder = get_pe_components(
329
+ df, "lmul", pe_dtype, is_fast, architecture
330
+ )
331
+ lmul_pe_metrics = calculate_pe_metrics(lmul_mult, lmul_adder)
332
+ lmul_array_metrics = calculate_array_metrics(
333
+ lmul_pe_metrics, systolic_array_size, num_accelerator_cores
334
+ )
335
+ lmul_formatted = format_metrics_for_display(lmul_array_metrics)
336
+ except ValueError as e:
337
+ # If lmul components not found, return error message
338
+ lmul_formatted = {"Error": f"Could not find lmul components: {str(e)}"}
339
+
340
+ # Calculate metrics for standard IEEE multiplier configuration
341
+ try:
342
+ ieee_mult, ieee_adder = get_pe_components(
343
+ df, "multiplier", pe_dtype, is_fast, architecture
344
+ )
345
+ ieee_pe_metrics = calculate_pe_metrics(ieee_mult, ieee_adder)
346
+ ieee_array_metrics = calculate_array_metrics(
347
+ ieee_pe_metrics, systolic_array_size, num_accelerator_cores
348
+ )
349
+ ieee_formatted = format_metrics_for_display(ieee_array_metrics)
350
+ except ValueError as e:
351
+ # If IEEE components not found, return error message
352
+ ieee_formatted = {"Error": f"Could not find IEEE components: {str(e)}"}
353
+
354
+ return lmul_formatted, ieee_formatted
355
+
356
+
357
+ def calculate_comparison_metrics(
358
+ lmul_metrics: Dict[str, str], ieee_metrics: Dict[str, str]
359
+ ) -> Dict[str, str]:
360
+ """
361
+ Calculate comparison metrics between lmul and IEEE configurations.
362
+
363
+ Args:
364
+ lmul_metrics: Dictionary of lmul metrics
365
+ ieee_metrics: Dictionary of IEEE metrics
366
+
367
+ Returns:
368
+ Dictionary of comparison metrics
369
+ """
370
+ comparison = {}
371
+
372
+ # Check if there was an error in either calculation
373
+ if "Error" in lmul_metrics or "Error" in ieee_metrics:
374
+ return {"Error": "Cannot calculate comparison due to missing components"}
375
+
376
+ # Extract numeric values from formatted strings
377
+ def extract_number(s):
378
+ return float(s.split()[0])
379
+
380
+ # Calculate percentage improvements
381
+ try:
382
+ # Area improvement (lower is better)
383
+ lmul_area = extract_number(lmul_metrics["Total Chip Area"])
384
+ ieee_area = extract_number(ieee_metrics["Total Chip Area"])
385
+ area_improvement = (1 - lmul_area / ieee_area) * 100
386
+ comparison["Area Reduction"] = f"{area_improvement:.1f}%"
387
+
388
+ # Performance improvement (higher is better)
389
+ lmul_perf = extract_number(lmul_metrics["Total Performance"])
390
+ ieee_perf = extract_number(ieee_metrics["Total Performance"])
391
+ perf_improvement = (lmul_perf / ieee_perf - 1) * 100
392
+ comparison["Performance Improvement"] = f"{perf_improvement:.1f}%"
393
+
394
+ # Power efficiency improvement (higher is better)
395
+ lmul_eff = extract_number(lmul_metrics["Performance per Watt"])
396
+ ieee_eff = extract_number(ieee_metrics["Performance per Watt"])
397
+ eff_improvement = (lmul_eff / ieee_eff - 1) * 100
398
+ comparison["Efficiency Improvement"] = f"{eff_improvement:.1f}%"
399
+
400
+ # Latency improvement (lower is better)
401
+ lmul_latency = extract_number(lmul_metrics["Inference Latency"])
402
+ ieee_latency = extract_number(ieee_metrics["Inference Latency"])
403
+ latency_improvement = (1 - lmul_latency / ieee_latency) * 100
404
+ comparison["Latency Reduction"] = f"{latency_improvement:.1f}%"
405
+
406
+ # Energy improvement (lower is better). "Inference Energy" is only present
+ # when format_metrics_for_display emits it (currently commented out there),
+ # so guard against a missing key instead of raising.
+ if "Inference Energy" in lmul_metrics and "Inference Energy" in ieee_metrics:
+ lmul_energy = extract_number(lmul_metrics["Inference Energy"])
+ ieee_energy = extract_number(ieee_metrics["Inference Energy"])
+ energy_improvement = (1 - lmul_energy / ieee_energy) * 100
+ comparison["Energy Reduction"] = f"{energy_improvement:.1f}%"
411
+
412
+ except (ValueError, KeyError) as e:
413
+ comparison["Error"] = f"Error calculating comparisons: {str(e)}"
414
+
415
+ return comparison
416
+
417
+
418
+ # Example usage in the Gradio app:
419
+ def update_hardware_stats(
420
+ df: pd.DataFrame,
421
+ activation_type: str,
422
+ weight_type: str,
423
+ systolic_array_size: int,
424
+ num_accelerator_cores: int,
425
+ fast_internals: str,
426
+ pipeline_level: str,
427
+ process_node_size: str,
428
+ ) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
429
+ """
430
+ Update hardware statistics for the Gradio app.
431
+
432
+ Args:
433
+ df: DataFrame containing component data
434
+ activation_type: Type of activations
435
+ weight_type: Type of weights
436
+ systolic_array_size: Size of the systolic array
437
+ num_accelerator_cores: Number of accelerator cores
438
+ fast_internals: Whether to use fast or efficient components
439
+ pipeline_level: Level of pipelining
440
+ process_node_size: Process node size
441
+
442
+ Returns:
443
+ Tuple of (lmul_metrics, ieee_metrics, comparison_metrics) dictionaries
444
+ """
445
+ lmul_metrics, ieee_metrics = calculate_hardware_stats(
446
+ df,
447
+ activation_type,
448
+ weight_type,
449
+ systolic_array_size,
450
+ num_accelerator_cores,
451
+ fast_internals,
452
+ pipeline_level,
453
+ process_node_size,
454
+ )
455
+
456
+ comparison_metrics = calculate_comparison_metrics(lmul_metrics, ieee_metrics)
457
+
458
+ return lmul_metrics, ieee_metrics, comparison_metrics
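A minimal usage sketch of the helpers above, for illustration only. It assumes a component CSV with the columns referenced by filter_components and calculate_pe_metrics (area, power, max_arrival_time, architecture); the path mirrors the one read by the Gradio app.

import pandas as pd

df = pd.read_csv("results/component_data.csv")
# Pick an l-mul based PE in bf16, then scale it to a 16x16 array with 4 cores.
mult, adder = get_pe_components(df, "lmul", "bf16", True, "pipelined")
pe = calculate_pe_metrics(mult, adder)
array = calculate_array_metrics(pe, array_size=16, num_cores=4)
print(format_metrics_for_display(array)["Total Performance"])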
hardware_accelerators/analysis/mnist_eval.py ADDED
@@ -0,0 +1,274 @@
1
+ # Evaluation function
2
+ from itertools import product
3
+ import os
4
+ import pandas as pd
5
+ from pyrtl import WireVector
6
+ import torch
7
+ from typing import Callable, Type
8
+ import numpy as np
9
+ from torchvision import datasets, transforms
10
+ from torch.utils.data import DataLoader
11
+ from torch.nn import CrossEntropyLoss
12
+ import multiprocessing as mp
13
+ from tqdm.auto import tqdm
14
+ import time
15
+ import csv
16
+ from pathlib import Path
17
+ import traceback
18
+
19
+ from ..simulation.matrix_utils import count_batch_gemm_tiles
20
+
21
+ from .config import (
22
+ NN_TEST_MUL_FNS,
23
+ NN_TEST_SYSTOLIC_ARRAY_SIZE,
24
+ NN_TEST_ACCUM_ADDR_WIDTH,
25
+ NN_TEST_WA_DTYPES,
26
+ NN_TEST_BATCH_SIZE,
27
+ )
28
+ from hardware_accelerators.dtypes import *
29
+ from hardware_accelerators.simulation.accelerator import CompiledAcceleratorSimulator
30
+ from hardware_accelerators.rtllib.accelerator import CompiledAcceleratorConfig
31
+ from hardware_accelerators.rtllib.multipliers import *
32
+ from hardware_accelerators.nn import load_model
33
34
+
35
+
36
+ def generate_test_configs(
37
+ weight_act_dtypes: list[tuple[Type[BaseFloat], Type[BaseFloat]]],
38
+ multiplier_fns: list[
39
+ Callable[[WireVector, WireVector, type[BaseFloat]], WireVector]
40
+ ],
41
+ ):
42
+ configs = []
43
+ for dtypes, mult_fn in product(weight_act_dtypes, multiplier_fns):
44
+ weight_type, act_type = dtypes
45
+ config = CompiledAcceleratorConfig(
46
+ array_size=NN_TEST_SYSTOLIC_ARRAY_SIZE,
47
+ weight_type=weight_type,
48
+ activation_type=act_type,
49
+ multiplier=mult_fn,
50
+ accum_addr_width=NN_TEST_ACCUM_ADDR_WIDTH,
51
+ )
52
+ configs.append(config)
53
+ return configs
54
+
55
+
56
+ def evaluate_with_progress(
57
+ config,
58
+ dataset,
59
+ batch_size,
60
+ criterion=CrossEntropyLoss(),
61
+ process_id=0,
62
+ ):
63
+ """Evaluate a model with progress tracking for the entire dataset"""
64
+ # Define a complete result template with default values
65
+ result = {
66
+ "config": config.name,
67
+ "weight_type": config.weight_type.__name__,
68
+ "activation_type": config.activation_type.__name__,
69
+ "multiplier": config.multiplier.__name__,
70
+ "avg_loss": float("nan"),
71
+ "accuracy": float("nan"),
72
+ "total_time": 0,
73
+ "batch_size": batch_size,
74
+ "total_batches": 0,
75
+ "total_samples": 0,
76
+ "samples_per_second": 0,
77
+ "error": None, # Will be None for successful runs
78
+ }
79
+
80
+ try:
81
+ start_time = time.time()
82
+
83
+ # Load the appropriate model based on weight type
84
+ if config.weight_type == Float32:
85
+ model = load_model("./models/mlp_mnist_fp32.pth")
86
+ else:
87
+ model = load_model("./models/mlp_mnist_bf16.pth")
88
+
89
+ # Create simulator
90
+ sim = CompiledAcceleratorSimulator(config, model=model)
91
+
92
+ if not sim.model_loaded:
93
+ result["error"] = "No model loaded"
94
+ return result
95
+
96
+ correct = 0
97
+ total = 0
98
+ running_loss = 0.0
99
+
100
+ data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
101
+ total_batches = len(data_loader)
102
+ tiles_per_batch = count_batch_gemm_tiles(
103
+ sim.hidden_dim, sim.input_dim + 1, sim.config.array_size
104
+ ) + count_batch_gemm_tiles(
105
+ sim.output_dim, sim.hidden_dim + 1, sim.config.array_size
106
+ )
107
+ result["total_batches"] = total_batches
108
+
109
+ # Create a progress bar for this specific simulation
110
+ desc = f"Config {config.name} ({config.weight_type.__name__}/{config.activation_type.__name__})"
111
+ with tqdm(
112
+ total=total_batches * tiles_per_batch,
113
+ desc=desc,
114
+ position=process_id + 1,
115
+ leave=False,
116
+ ) as pbar:
117
+ for batch, labels in data_loader:
118
+ batch_size_actual = batch.shape[0]
119
+ batch = batch.reshape(batch_size_actual, -1).numpy()
120
+
121
+ # Time the prediction
122
+ outputs = sim.predict_batch(batch, pbar)
123
+
124
+ loss = criterion(torch.tensor(outputs), labels)
125
+ running_loss += loss.item()
126
+
127
+ # Get predictions from the maximum value
128
+ predicted = np.argmax(outputs, axis=1)
129
+ total += labels.size(0)
130
+ correct += (predicted == labels).sum().item()
131
+
132
+ end_time = time.time()
133
+ total_time = end_time - start_time
134
+
135
+ # Update result with actual values
136
+ result.update(
137
+ {
138
+ "avg_loss": running_loss / total_batches,
139
+ "accuracy": 100.0 * correct / total,
140
+ "total_time": total_time,
141
+ "total_samples": total,
142
+ "samples_per_second": total / total_time,
143
+ }
144
+ )
145
+
146
+ return result
147
+
148
+ except Exception as e:
149
+ error_msg = (
150
+ f"Error evaluating {config.name}: {str(e)}\n{traceback.format_exc()}"
151
+ )
152
+ print(error_msg)
153
+ result["error"] = str(e)
154
+ return result
155
+
156
+
157
+ def save_result(result, output_file):
158
+ """Save a single result to CSV file"""
159
+ file_exists = os.path.isfile(output_file)
160
+
161
+ with open(output_file, "a", newline="") as f:
162
+ writer = csv.DictWriter(f, fieldnames=list(result.keys()))
163
+ if not file_exists:
164
+ writer.writeheader()
165
+ writer.writerow(result)
166
+
167
+
168
+ def process_config(config, dataset, batch_size, output_file, process_id):
169
+ """Process a single configuration and save results"""
170
+ # Set process name for better monitoring
171
+ mp.current_process().name = f"Sim-{config.name}"
172
+
173
+ # print(f"Starting evaluation of {config.name}")
174
+ result = evaluate_with_progress(config, dataset, batch_size, process_id=process_id)
175
+
176
+ # Save result immediately
177
+ save_result(result, output_file)
178
+
179
+ print(
180
+ f"Completed evaluation of {config.name}: Accuracy = {result.get('accuracy', 'ERROR'):.2f}%, "
181
+ f"Time = {result.get('total_time', 0):.2f}s, "
182
+ f"Speed = {result.get('samples_per_second', 0):.2f} samples/s"
183
+ )
184
+ return result
185
+
186
+
187
+ def main():
188
+ # Create output directory
189
+ output_dir = Path("results")
190
+ output_dir.mkdir(exist_ok=True)
191
+ output_file = output_dir / "mnist_eval.csv"
192
+
193
+ # Data transformation: convert images to tensor and normalize them
194
+ transform = transforms.Compose(
195
+ [
196
+ transforms.ToTensor(),
197
+ transforms.Normalize((0.1307,), (0.3081,)),
198
+ ]
199
+ )
200
+
201
+ # Download MNIST test data
202
+ test_dataset = datasets.MNIST(
203
+ root="./data", train=False, download=True, transform=transform
204
+ )
205
+
206
+ configs = generate_test_configs(NN_TEST_WA_DTYPES, NN_TEST_MUL_FNS)
207
+
208
+ # Create a multiprocessing pool
209
+ num_processes = max(1, min(len(configs), mp.cpu_count() - 2))  # keep at least one worker
210
+ print(f"Using {num_processes} processes to evaluate {len(configs)} configurations")
211
+ print(
212
+ f"Each simulation will process the entire test dataset with batch size {NN_TEST_BATCH_SIZE}"
213
+ )
214
+
215
+ # Clear the screen and set up for progress bars
216
+ print("\n\n")
217
+
218
+ # Start the pool and process configurations
219
+ with mp.Pool(processes=num_processes) as pool:
220
+ # Create a list to track all tasks
221
+ tasks = []
222
+
223
+ # Submit all tasks
224
+ for i, config in enumerate(configs):
225
+ task = pool.apply_async(
226
+ process_config,
227
+ args=(config, test_dataset, NN_TEST_BATCH_SIZE, output_file, i),
228
+ )
229
+ tasks.append((config.name, task))
230
+
231
+ # Set up the main progress bar at the top
232
+ with tqdm(total=len(tasks), desc="Overall Progress", position=0) as pbar:
233
+ completed = 0
234
+ while completed < len(tasks):
235
+ new_completed = sum(1 for _, task in tasks if task.ready())
236
+ if new_completed > completed:
237
+ pbar.update(new_completed - completed)
238
+ completed = new_completed
239
+ time.sleep(0.5)
240
+
241
+ # Make sure all tasks are properly completed and collect results
242
+ all_results = []
243
+ for config_name, task in tasks:
244
+ try:
245
+ result = task.get()
246
+ all_results.append(result)
247
+ except Exception as e:
248
+ print(f"Error in task {config_name}: {str(e)}")
249
+
250
+ print(f"All evaluations complete. Results saved to {output_file}")
251
+
252
+ # Create a summary DataFrame and display it
253
+ if all_results:
254
+ df = pd.DataFrame(all_results)
255
+ print("\nSummary of Results (sorted by accuracy):")
256
+ summary_cols = [
257
+ "config",
258
+ "weight_type",
259
+ "activation_type",
260
+ "multiplier",
261
+ "accuracy",
262
+ "total_time",
263
+ "samples_per_second",
264
+ ]
265
+ print(df[summary_cols].sort_values("accuracy", ascending=False))
266
+
267
+ print("\nSummary of Results (sorted by speed):")
268
+ print(df[summary_cols].sort_values("samples_per_second", ascending=False))
269
+
270
+
271
+ if __name__ == "__main__":
272
+ # Set start method for multiprocessing
273
+ mp.set_start_method("spawn", force=True) # Use 'spawn' for better compatibility
274
+ main()
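For debugging, a serial sketch of a single-configuration run that skips the multiprocessing pool; it assumes the same MNIST setup as main() and that the model checkpoints referenced in evaluate_with_progress exist on disk.

from torchvision import datasets, transforms

transform = transforms.Compose(
    [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
)
test_dataset = datasets.MNIST(root="./data", train=False, download=True, transform=transform)
config = generate_test_configs(NN_TEST_WA_DTYPES, NN_TEST_MUL_FNS)[0]  # first config only
result = evaluate_with_progress(config, test_dataset, NN_TEST_BATCH_SIZE)
print(result["accuracy"], result["samples_per_second"])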
hardware_accelerators/analysis/simple_circuits.py ADDED
@@ -0,0 +1,258 @@
1
+ import pyrtl
2
+ from pyrtl import WireVector, Input, Output, Simulation
3
+ import numpy as np
4
+ import sys
5
+ import os
6
+
7
+ # Add the parent directory to the path so we can import from hardware_accelerators
8
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
9
+
10
+ from hardware_accelerators.dtypes import Float16, Float32, Float8, BF16
11
+
12
+
13
+ def create_simple_adder(data_type):
14
+ """
15
+ Create a simple adder circuit (a + b) using PyRTL's built-in operators.
16
+
17
+ Args:
18
+ data_type: The data type to use (only used for bitwidth)
19
+
20
+ Returns:
21
+ The PyRTL working block
22
+ """
23
+ # Clear any existing PyRTL design
24
+ pyrtl.reset_working_block()
25
+
26
+ # Create input and output wires
27
+ a = Input(data_type.bitwidth(), 'a')
28
+ b = Input(data_type.bitwidth(), 'b')
29
+ result = Output(data_type.bitwidth(), 'result')
30
+
31
+ # Create adder using PyRTL's built-in addition
32
+ # Note: This treats the inputs as unsigned integers, not floating point
33
+ result <<= a + b
34
+
35
+ return pyrtl.working_block()
36
+
37
+
38
+ def create_simple_multiplier(data_type):
39
+ """
40
+ Create a simple multiplier circuit (a * b) using PyRTL's built-in operators.
41
+
42
+ Args:
43
+ data_type: The data type to use (only used for bitwidth)
44
+
45
+ Returns:
46
+ The PyRTL working block
47
+ """
48
+ # Clear any existing PyRTL design
49
+ pyrtl.reset_working_block()
50
+
51
+ # Create input and output wires
52
+ a = Input(data_type.bitwidth(), 'a')
53
+ b = Input(data_type.bitwidth(), 'b')
54
+ result = Output(data_type.bitwidth(), 'result')
55
+
56
+ # Create multiplier using PyRTL's built-in multiplication
57
+ # Note: This treats the inputs as unsigned integers, not floating point
58
+ # We'll truncate the result to match the input bitwidth
59
+ mult_result = a * b
60
+ result <<= mult_result[:data_type.bitwidth()]
61
+
62
+ return pyrtl.working_block()
63
+
64
+
65
+ def create_pipelined_adder(data_type):
66
+ """
67
+ Create a pipelined adder circuit (a + b) using PyRTL's built-in operators.
68
+
69
+ Args:
70
+ data_type: The data type to use (only used for bitwidth)
71
+
72
+ Returns:
73
+ The PyRTL working block and the result wire
74
+ """
75
+ # Clear any existing PyRTL design
76
+ pyrtl.reset_working_block()
77
+
78
+ # Create input and output wires
79
+ a = Input(data_type.bitwidth(), 'a')
80
+ b = Input(data_type.bitwidth(), 'b')
81
+ result = Output(data_type.bitwidth(), 'result')
82
+
83
+ # Create pipeline registers
84
+ a_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='a_reg')
85
+ b_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='b_reg')
86
+
87
+ # Connect input to registers
88
+ a_reg.next <<= a
89
+ b_reg.next <<= b
90
+
91
+ # Perform addition in the next stage using PyRTL's built-in addition
92
+ add_result = a_reg + b_reg
93
+
94
+ # Connect to output
95
+ result <<= add_result
96
+
97
+ return pyrtl.working_block(), result
98
+
99
+
100
+ def create_pipelined_multiplier(data_type):
101
+ """
102
+ Create a pipelined multiplier circuit (a * b) using PyRTL's built-in operators.
103
+
104
+ Args:
105
+ data_type: The data type to use (only used for bitwidth)
106
+
107
+ Returns:
108
+ The PyRTL working block and the result wire
109
+ """
110
+ # Clear any existing PyRTL design
111
+ pyrtl.reset_working_block()
112
+
113
+ # Create input and output wires
114
+ a = Input(data_type.bitwidth(), 'a')
115
+ b = Input(data_type.bitwidth(), 'b')
116
+ result = Output(data_type.bitwidth(), 'result')
117
+
118
+ # Create pipeline registers
119
+ a_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='a_reg')
120
+ b_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='b_reg')
121
+
122
+ # Connect input to registers
123
+ a_reg.next <<= a
124
+ b_reg.next <<= b
125
+
126
+ # Perform multiplication in the next stage using PyRTL's built-in multiplication
127
+ mult_result = a_reg * b_reg
128
+
129
+ # Truncate the result to match the input bitwidth
130
+ truncated_result = mult_result[:data_type.bitwidth()]
131
+
132
+ # Connect to output
133
+ result <<= truncated_result
134
+
135
+ return pyrtl.working_block(), result
136
+
137
+
138
+ def simulate_circuit(block, data_type, num_cycles=10):
139
+ """
140
+ Simulate a circuit with random inputs.
141
+
142
+ Args:
143
+ block: The PyRTL block to simulate
144
+ data_type: The data type used (only for bitwidth)
145
+ num_cycles: Number of simulation cycles
146
+
147
+ Returns:
148
+ sim: The simulation object
149
+ trace: The simulation trace
150
+ """
151
+ # Create simulation with tracing enabled
152
+ sim = Simulation()
153
+
154
+ # Create test data (random integers within the valid range for the bitwidth)
155
+ bitwidth = data_type.bitwidth()
156
+
157
+ # Handle large bitwidths safely
158
+ if bitwidth > 30: # Avoid overflow for large bitwidths
159
+ a_values = [np.random.randint(0, 2**16) for _ in range(num_cycles)]
160
+ b_values = [np.random.randint(0, 2**16) for _ in range(num_cycles)]
161
+ else:
162
+ max_val = 2**bitwidth - 1
163
+ a_values = [np.random.randint(0, max_val) for _ in range(num_cycles)]
164
+ b_values = [np.random.randint(0, max_val) for _ in range(num_cycles)]
165
+
166
+ # Create input dictionaries for each cycle
167
+ input_vectors = []
168
+ for i in range(num_cycles):
169
+ cycle_inputs = {
170
+ 'a': a_values[i],
171
+ 'b': b_values[i]
172
+ }
173
+ input_vectors.append(cycle_inputs)
174
+
175
+ # Run simulation for each cycle
176
+ for i in range(num_cycles):
177
+ sim.step(input_vectors[i])
178
+
179
+ # Get trace from the tracer
180
+ tracer = sim.tracer
181
+
182
+ return sim, tracer
183
+
184
+
185
+ def main():
186
+ """Main function to create and test the simplified circuits."""
187
+ # Data types to test (only used for bitwidth)
188
+ data_types = [Float8, BF16, Float16, Float32]
189
+
190
+ # Results dictionary
191
+ results = {
192
+ "simple_adder": {},
193
+ "simple_multiplier": {},
194
+ "pipelined_adder": {},
195
+ "pipelined_multiplier": {}
196
+ }
197
+
198
+ # Test simple adders
199
+ print("=== Testing Simple Adders ===")
200
+ for dtype in data_types:
201
+ print(f"\nCreating and simulating {dtype.__name__} adder (bitwidth: {dtype.bitwidth()})...")
202
+ block = create_simple_adder(dtype)
203
+ sim, tracer = simulate_circuit(block, dtype)
204
+
205
+ # Print some results
206
+ if 'result' in tracer.trace:
207
+ output_values = tracer.trace['result']
208
+ print(f" result: {output_values}")
209
+
210
+ results["simple_adder"][dtype.__name__] = block
211
+
212
+ # Test simple multipliers
213
+ print("\n=== Testing Simple Multipliers ===")
214
+ for dtype in data_types:
215
+ print(f"\nCreating and simulating {dtype.__name__} multiplier (bitwidth: {dtype.bitwidth()})...")
216
+ block = create_simple_multiplier(dtype)
217
+ sim, tracer = simulate_circuit(block, dtype)
218
+
219
+ # Print some results
220
+ if 'result' in tracer.trace:
221
+ output_values = tracer.trace['result']
222
+ print(f" result: {output_values}")
223
+
224
+ results["simple_multiplier"][dtype.__name__] = block
225
+
226
+ # Test pipelined adders
227
+ print("\n=== Testing Pipelined Adders ===")
228
+ for dtype in data_types:
229
+ print(f"\nCreating and simulating {dtype.__name__} pipelined adder (bitwidth: {dtype.bitwidth()})...")
230
+ block, _ = create_pipelined_adder(dtype)
231
+ sim, tracer = simulate_circuit(block, dtype)
232
+
233
+ # Print some results
234
+ if 'result' in tracer.trace:
235
+ output_values = tracer.trace['result']
236
+ print(f" result: {output_values}")
237
+
238
+ results["pipelined_adder"][dtype.__name__] = block
239
+
240
+ # Test pipelined multipliers
241
+ print("\n=== Testing Pipelined Multipliers ===")
242
+ for dtype in data_types:
243
+ print(f"\nCreating and simulating {dtype.__name__} pipelined multiplier (bitwidth: {dtype.bitwidth()})...")
244
+ block, _ = create_pipelined_multiplier(dtype)
245
+ sim, tracer = simulate_circuit(block, dtype)
246
+
247
+ # Print some results
248
+ if 'result' in tracer.trace:
249
+ output_values = tracer.trace['result']
250
+ print(f" result: {output_values}")
251
+
252
+ results["pipelined_multiplier"][dtype.__name__] = block
253
+
254
+ return results
255
+
256
+
257
+ if __name__ == "__main__":
258
+ main()
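A direct-use sketch for building and simulating a single circuit outside of main(); the inputs are random unsigned integers, so the trace shows raw integer arithmetic rather than floating-point results.

block = create_simple_adder(BF16)
sim, tracer = simulate_circuit(block, BF16, num_cycles=4)
print(tracer.trace["result"])  # one integer output per simulated cycle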
hardware_accelerators/analysis/verilog_export.py ADDED
@@ -0,0 +1,86 @@
1
+ import os
2
+ import sys
3
+ import pyrtl
4
+ from pyrtl import WireVector, Input, Output, Simulation
5
+
6
+ # Add the parent directory to the path so we can import from hardware_accelerators
7
+ sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
8
+
9
+ from hardware_accelerators.dtypes import Float16, Float32, Float8, BF16
10
+ from hardware_accelerators.analysis.simple_circuits import (
11
+ create_simple_adder,
12
+ create_simple_multiplier,
13
+ create_pipelined_adder,
14
+ create_pipelined_multiplier
15
+ )
16
+
17
+ def export_to_verilog(block, output_filename, add_reset=True, initialize_registers=False):
18
+ """
19
+ Export a PyRTL block to a Verilog file.
20
+
21
+ Args:
22
+ block: The PyRTL block to export
23
+ output_filename: The filename to write the Verilog code to
24
+ add_reset: If reset logic should be added. Options are:
25
+ False (no reset logic),
26
+ True (synchronous reset logic),
27
+ 'asynchronous' (asynchronous reset logic)
28
+ initialize_registers: Initialize Verilog registers to their reset_value
29
+ """
30
+ # Create the output directory if it doesn't exist
31
+ os.makedirs(os.path.dirname(output_filename), exist_ok=True)
32
+
33
+ # Export the block to Verilog
34
+ with open(output_filename, 'w') as f:
35
+ pyrtl.output_to_verilog(
36
+ f,
37
+ add_reset=add_reset,
38
+ initialize_registers=initialize_registers,
39
+ block=block
40
+ )
41
+
42
+ print(f"Exported Verilog to {output_filename}")
43
+
44
+ def export_all_circuits():
45
+ """
46
+ Export all simple circuits to Verilog files.
47
+ """
48
+ # Create output directory
49
+ output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "verilog_output")
50
+ os.makedirs(output_dir, exist_ok=True)
51
+
52
+ # List of data types to use
53
+ data_types = [Float8, Float16, BF16, Float32]
54
+
55
+ # Export simple adder for each data type
56
+ for dtype in data_types:
57
+ block = create_simple_adder(dtype)
58
+ output_filename = os.path.join(output_dir, f"simple_adder_{dtype.__name__}.v")
59
+ export_to_verilog(block, output_filename)
60
+
61
+ # Export simple multiplier for each data type
62
+ for dtype in data_types:
63
+ block = create_simple_multiplier(dtype)
64
+ output_filename = os.path.join(output_dir, f"simple_multiplier_{dtype.__name__}.v")
65
+ export_to_verilog(block, output_filename)
66
+
67
+ # Export pipelined adder for each data type
68
+ for dtype in data_types:
69
+ block, _ = create_pipelined_adder(dtype)
70
+ output_filename = os.path.join(output_dir, f"pipelined_adder_{dtype.__name__}.v")
71
+ export_to_verilog(block, output_filename, initialize_registers=True)
72
+
73
+ # Export pipelined multiplier for each data type
74
+ for dtype in data_types:
75
+ block, _ = create_pipelined_multiplier(dtype)
76
+ output_filename = os.path.join(output_dir, f"pipelined_multiplier_{dtype.__name__}.v")
77
+ export_to_verilog(block, output_filename, initialize_registers=True)
78
+
79
+ def main():
80
+ # Export all circuits
81
+ export_all_circuits()
82
+
83
+ print("All circuits exported to Verilog successfully!")
84
+
85
+ if __name__ == "__main__":
86
+ main()
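A single-file export sketch, using the same output naming convention as export_all_circuits.

block, _ = create_pipelined_multiplier(BF16)
export_to_verilog(block, "verilog_output/pipelined_multiplier_BF16.v", initialize_registers=True)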
hardware_accelerators/analysis/verilog_output/pipelined_adder_BF16.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ reg[15:0] a_reg = 16'd0;
13
+ reg[15:0] b_reg = 16'd0;
14
+
15
+ wire[16:0] tmp20;
16
+ wire[15:0] tmp21;
17
+
18
+ // Combinational
19
+ assign result = tmp21;
20
+ assign tmp20 = a_reg + b_reg;
21
+ assign tmp21 = {tmp20[15], tmp20[14], tmp20[13], tmp20[12], tmp20[11], tmp20[10], tmp20[9], tmp20[8], tmp20[7], tmp20[6], tmp20[5], tmp20[4], tmp20[3], tmp20[2], tmp20[1], tmp20[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_adder_Float16.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ reg[15:0] a_reg = 16'd0;
13
+ reg[15:0] b_reg = 16'd0;
14
+
15
+ wire[16:0] tmp18;
16
+ wire[15:0] tmp19;
17
+
18
+ // Combinational
19
+ assign result = tmp19;
20
+ assign tmp18 = a_reg + b_reg;
21
+ assign tmp19 = {tmp18[15], tmp18[14], tmp18[13], tmp18[12], tmp18[11], tmp18[10], tmp18[9], tmp18[8], tmp18[7], tmp18[6], tmp18[5], tmp18[4], tmp18[3], tmp18[2], tmp18[1], tmp18[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_adder_Float32.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[31:0] a;
9
+ input[31:0] b;
10
+ output[31:0] result;
11
+
12
+ reg[31:0] a_reg = 32'd0;
13
+ reg[31:0] b_reg = 32'd0;
14
+
15
+ wire[32:0] tmp22;
16
+ wire[31:0] tmp23;
17
+
18
+ // Combinational
19
+ assign result = tmp23;
20
+ assign tmp22 = a_reg + b_reg;
21
+ assign tmp23 = {tmp22[31], tmp22[30], tmp22[29], tmp22[28], tmp22[27], tmp22[26], tmp22[25], tmp22[24], tmp22[23], tmp22[22], tmp22[21], tmp22[20], tmp22[19], tmp22[18], tmp22[17], tmp22[16], tmp22[15], tmp22[14], tmp22[13], tmp22[12], tmp22[11], tmp22[10], tmp22[9], tmp22[8], tmp22[7], tmp22[6], tmp22[5], tmp22[4], tmp22[3], tmp22[2], tmp22[1], tmp22[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_adder_Float8.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[7:0] a;
9
+ input[7:0] b;
10
+ output[7:0] result;
11
+
12
+ reg[7:0] a_reg = 8'd0;
13
+ reg[7:0] b_reg = 8'd0;
14
+
15
+ wire[8:0] tmp16;
16
+ wire[7:0] tmp17;
17
+
18
+ // Combinational
19
+ assign result = tmp17;
20
+ assign tmp16 = a_reg + b_reg;
21
+ assign tmp17 = {tmp16[7], tmp16[6], tmp16[5], tmp16[4], tmp16[3], tmp16[2], tmp16[1], tmp16[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_BF16.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ reg[15:0] a_reg = 16'd0;
13
+ reg[15:0] b_reg = 16'd0;
14
+
15
+ wire[31:0] tmp28;
16
+ wire[15:0] tmp29;
17
+
18
+ // Combinational
19
+ assign result = tmp29;
20
+ assign tmp28 = a_reg * b_reg;
21
+ assign tmp29 = {tmp28[15], tmp28[14], tmp28[13], tmp28[12], tmp28[11], tmp28[10], tmp28[9], tmp28[8], tmp28[7], tmp28[6], tmp28[5], tmp28[4], tmp28[3], tmp28[2], tmp28[1], tmp28[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float16.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ reg[15:0] a_reg = 16'd0;
13
+ reg[15:0] b_reg = 16'd0;
14
+
15
+ wire[31:0] tmp26;
16
+ wire[15:0] tmp27;
17
+
18
+ // Combinational
19
+ assign result = tmp27;
20
+ assign tmp26 = a_reg * b_reg;
21
+ assign tmp27 = {tmp26[15], tmp26[14], tmp26[13], tmp26[12], tmp26[11], tmp26[10], tmp26[9], tmp26[8], tmp26[7], tmp26[6], tmp26[5], tmp26[4], tmp26[3], tmp26[2], tmp26[1], tmp26[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float32.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[31:0] a;
9
+ input[31:0] b;
10
+ output[31:0] result;
11
+
12
+ reg[31:0] a_reg = 32'd0;
13
+ reg[31:0] b_reg = 32'd0;
14
+
15
+ wire[63:0] tmp30;
16
+ wire[31:0] tmp31;
17
+
18
+ // Combinational
19
+ assign result = tmp31;
20
+ assign tmp30 = a_reg * b_reg;
21
+ assign tmp31 = {tmp30[31], tmp30[30], tmp30[29], tmp30[28], tmp30[27], tmp30[26], tmp30[25], tmp30[24], tmp30[23], tmp30[22], tmp30[21], tmp30[20], tmp30[19], tmp30[18], tmp30[17], tmp30[16], tmp30[15], tmp30[14], tmp30[13], tmp30[12], tmp30[11], tmp30[10], tmp30[9], tmp30[8], tmp30[7], tmp30[6], tmp30[5], tmp30[4], tmp30[3], tmp30[2], tmp30[1], tmp30[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float8.v ADDED
@@ -0,0 +1,37 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[7:0] a;
9
+ input[7:0] b;
10
+ output[7:0] result;
11
+
12
+ reg[7:0] a_reg = 8'd0;
13
+ reg[7:0] b_reg = 8'd0;
14
+
15
+ wire[15:0] tmp24;
16
+ wire[7:0] tmp25;
17
+
18
+ // Combinational
19
+ assign result = tmp25;
20
+ assign tmp24 = a_reg * b_reg;
21
+ assign tmp25 = {tmp24[7], tmp24[6], tmp24[5], tmp24[4], tmp24[3], tmp24[2], tmp24[1], tmp24[0]};
22
+
23
+ // Registers
24
+ always @(posedge clk)
25
+ begin
26
+ if (rst) begin
27
+ a_reg <= 0;
28
+ b_reg <= 0;
29
+ end
30
+ else begin
31
+ a_reg <= a;
32
+ b_reg <= b;
33
+ end
34
+ end
35
+
36
+ endmodule
37
+
hardware_accelerators/analysis/verilog_output/simple_adder_BF16.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ wire[16:0] tmp4;
13
+ wire[15:0] tmp5;
14
+
15
+ // Combinational
16
+ assign result = tmp5;
17
+ assign tmp4 = a + b;
18
+ assign tmp5 = {tmp4[15], tmp4[14], tmp4[13], tmp4[12], tmp4[11], tmp4[10], tmp4[9], tmp4[8], tmp4[7], tmp4[6], tmp4[5], tmp4[4], tmp4[3], tmp4[2], tmp4[1], tmp4[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_adder_Float16.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ wire[16:0] tmp2;
13
+ wire[15:0] tmp3;
14
+
15
+ // Combinational
16
+ assign result = tmp3;
17
+ assign tmp2 = a + b;
18
+ assign tmp3 = {tmp2[15], tmp2[14], tmp2[13], tmp2[12], tmp2[11], tmp2[10], tmp2[9], tmp2[8], tmp2[7], tmp2[6], tmp2[5], tmp2[4], tmp2[3], tmp2[2], tmp2[1], tmp2[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_adder_Float32.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[31:0] a;
9
+ input[31:0] b;
10
+ output[31:0] result;
11
+
12
+ wire[32:0] tmp6;
13
+ wire[31:0] tmp7;
14
+
15
+ // Combinational
16
+ assign result = tmp7;
17
+ assign tmp6 = a + b;
18
+ assign tmp7 = {tmp6[31], tmp6[30], tmp6[29], tmp6[28], tmp6[27], tmp6[26], tmp6[25], tmp6[24], tmp6[23], tmp6[22], tmp6[21], tmp6[20], tmp6[19], tmp6[18], tmp6[17], tmp6[16], tmp6[15], tmp6[14], tmp6[13], tmp6[12], tmp6[11], tmp6[10], tmp6[9], tmp6[8], tmp6[7], tmp6[6], tmp6[5], tmp6[4], tmp6[3], tmp6[2], tmp6[1], tmp6[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_adder_Float8.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[7:0] a;
9
+ input[7:0] b;
10
+ output[7:0] result;
11
+
12
+ wire[8:0] tmp0;
13
+ wire[7:0] tmp1;
14
+
15
+ // Combinational
16
+ assign result = tmp1;
17
+ assign tmp0 = a + b;
18
+ assign tmp1 = {tmp0[7], tmp0[6], tmp0[5], tmp0[4], tmp0[3], tmp0[2], tmp0[1], tmp0[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_BF16.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ wire[31:0] tmp12;
13
+ wire[15:0] tmp13;
14
+
15
+ // Combinational
16
+ assign result = tmp13;
17
+ assign tmp12 = a * b;
18
+ assign tmp13 = {tmp12[15], tmp12[14], tmp12[13], tmp12[12], tmp12[11], tmp12[10], tmp12[9], tmp12[8], tmp12[7], tmp12[6], tmp12[5], tmp12[4], tmp12[3], tmp12[2], tmp12[1], tmp12[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_Float16.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[15:0] a;
9
+ input[15:0] b;
10
+ output[15:0] result;
11
+
12
+ wire[31:0] tmp10;
13
+ wire[15:0] tmp11;
14
+
15
+ // Combinational
16
+ assign result = tmp11;
17
+ assign tmp10 = a * b;
18
+ assign tmp11 = {tmp10[15], tmp10[14], tmp10[13], tmp10[12], tmp10[11], tmp10[10], tmp10[9], tmp10[8], tmp10[7], tmp10[6], tmp10[5], tmp10[4], tmp10[3], tmp10[2], tmp10[1], tmp10[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_Float32.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[31:0] a;
9
+ input[31:0] b;
10
+ output[31:0] result;
11
+
12
+ wire[63:0] tmp14;
13
+ wire[31:0] tmp15;
14
+
15
+ // Combinational
16
+ assign result = tmp15;
17
+ assign tmp14 = a * b;
18
+ assign tmp15 = {tmp14[31], tmp14[30], tmp14[29], tmp14[28], tmp14[27], tmp14[26], tmp14[25], tmp14[24], tmp14[23], tmp14[22], tmp14[21], tmp14[20], tmp14[19], tmp14[18], tmp14[17], tmp14[16], tmp14[15], tmp14[14], tmp14[13], tmp14[12], tmp14[11], tmp14[10], tmp14[9], tmp14[8], tmp14[7], tmp14[6], tmp14[5], tmp14[4], tmp14[3], tmp14[2], tmp14[1], tmp14[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_Float8.v ADDED
@@ -0,0 +1,21 @@
1
+ // Generated automatically via PyRTL
2
+ // As one initial test of synthesis, map to FPGA with:
3
+ // yosys -p "synth_xilinx -top toplevel" thisfile.v
4
+
5
+ module toplevel(clk, rst, a, b, result);
6
+ input clk;
7
+ input rst;
8
+ input[7:0] a;
9
+ input[7:0] b;
10
+ output[7:0] result;
11
+
12
+ wire[15:0] tmp8;
13
+ wire[7:0] tmp9;
14
+
15
+ // Combinational
16
+ assign result = tmp9;
17
+ assign tmp8 = a * b;
18
+ assign tmp9 = {tmp8[7], tmp8[6], tmp8[5], tmp8[4], tmp8[3], tmp8[2], tmp8[1], tmp8[0]};
19
+
20
+ endmodule
21
+
hardware_accelerators/app.py ADDED
@@ -0,0 +1,388 @@
1
+ import os
2
+ from typing import Literal
3
+ import gradio as gr
4
+ from gradio.components.image_editor import EditorValue
5
+ import numpy as np
6
+ import pandas as pd
7
+ from PIL import Image
8
+ from torchvision import transforms
9
+ from .nn.util import load_model
10
+ from .rtllib.lmul import lmul_simple
11
+ from .rtllib.multipliers import float_multiplier
12
+ from .dtypes import Float8, BF16
13
+ from .rtllib import (
14
+ CompiledAcceleratorConfig,
15
+ )
16
+ from .simulation import CompiledAcceleratorSimulator
17
+ from .analysis.hardware_stats import (
18
+ calculate_hardware_stats,
19
+ )
20
+
21
+ __all__ = ["create_app"]
22
+
23
+ # ------------ CONSTANTS ------------ #
24
+
25
+ # Load the component data
26
+ data_path = os.environ.get("COMPONENT_DATA_PATH", "results/component_data.csv")
27
+ DF = pd.read_csv(data_path)
28
+
29
+ # Load the trained model
30
+ MODEL = load_model("models/mlp_mnist.pth", "cpu") # type: ignore
31
+ MODEL.eval()
32
+
33
+ classes = [
34
+ "zero",
35
+ "one",
36
+ "two",
37
+ "three",
38
+ "four",
39
+ "five",
40
+ "six",
41
+ "seven",
42
+ "eight",
43
+ "nine",
44
+ ]
45
+ labels_value = {label: 0.0 for label in classes}
46
+
47
+ accelerator_dtypes = ["float8", "bfloat16", "float32"]
48
+ dtype_map = {
49
+ "float8": Float8,
50
+ "bfloat16": BF16,
51
+ "float32": BF16, # TODO: use Float32, but not right now because its slow
52
+ }
53
+
54
+
55
+ mult_map = {
56
+ "IEEE 754": float_multiplier,
57
+ "l-mul": lmul_simple,
58
+ }
59
+
60
+
61
+ # ------------ Event Listener Functions ------------ #
62
+
63
+
64
+ def filter_activation_types(weight_type: str, activation_type: str):
65
+ if weight_type == "float8":
66
+ return gr.update(choices=accelerator_dtypes)
67
+ elif weight_type == "bfloat16":
68
+ if activation_type == "float8":
69
+ activation_type = "bfloat16"
70
+ return gr.update(value=activation_type, choices=["bfloat16", "float32"])
71
+ elif weight_type == "float32":
72
+ if activation_type != "float32":
73
+ activation_type = "float32"
74
+ return gr.update(value=activation_type, choices=["float32"])
75
+
76
+
77
+ def warn_w8a8(weight_type: str, activation_type: str):
78
+ if weight_type == "float8" and activation_type == "float8":
79
+ gr.Warning(
80
+ "W8A8 has poor performance without quantization, which is not yet supported in simulation. Theoretical results are still calculated for FP8 hardware",
81
+ duration=5,
82
+ )
83
+
84
+
85
+ def image_to_tensor(sketchpad: EditorValue):
86
+ image = sketchpad["composite"]
87
+ image = image.resize((28, 28), Image.Resampling.LANCZOS) # type: ignore
88
+ img_array = np.transpose(np.array(image), (2, 0, 1))[-1]
89
+
90
+ # Preprocessing: convert image to tensor and normalize
91
+ transform = transforms.Compose(
92
+ [
93
+ transforms.ToTensor(),
94
+ transforms.Normalize((0.1307,), (0.3081,)),
95
+ ]
96
+ )
97
+ tensor_image = transform(img_array)
98
+ return tensor_image
99
+
100
+
101
+ def calculate_stats(
102
+ activation_type: Literal["float8", "bfloat16", "float32"],
103
+ weight_type: Literal["float8", "bfloat16", "float32"],
104
+ systolic_array_size: int,
105
+ num_accelerator_cores: int,
106
+ fast_internals: Literal["Fast", "Efficient"],
107
+ pipeline_level: Literal["None", "Low", "Full"],
108
+ process_node_size: Literal["7nm", "45nm", "130nm"],
109
+ ):
110
+ """
111
+ Calculate hardware statistics for both lmul and standard IEEE multiplier configurations.
112
+
113
+ Args:
114
+ activation_type: Type of activations
115
+ weight_type: Type of weights
116
+ systolic_array_size: Size of the systolic array
117
+ num_accelerator_cores: Number of accelerator cores
118
+ fast_internals: Whether to use fast or efficient components
119
+ pipeline_level: Level of pipelining
120
+ process_node_size: Process node size (ignored for now)
121
+
122
+ Returns:
123
+ Tuple of (lmul_html, ieee_html) HTML strings for display
124
+ """
125
+ stat_map = {
126
+ "float8": "fp8",
127
+ "bfloat16": "bf16",
128
+ "float32": "fp32",
129
+ "Fast": True,
130
+ "Efficient": False,
131
+ "None": 0,
132
+ "Low": 1,
133
+ "None": "combinational",
134
+ "Low": "combinational",
135
+ "Full": "pipelined",
136
+ "7nm": 7,
137
+ "45nm": 45,
138
+ "130nm": 130,
139
+ }
140
+ # Calculate hardware stats using the functions from hardware_stats.py
141
+ lmul_metrics, ieee_metrics = calculate_hardware_stats(
142
+ DF,
143
+ activation_type,
144
+ weight_type,
145
+ systolic_array_size,
146
+ num_accelerator_cores,
147
+ fast_internals,
148
+ pipeline_level,
149
+ process_node_size,
150
+ )
151
+
152
+ # comparison_metrics = calculate_comparison_metrics(lmul_metrics, ieee_metrics)
153
+
154
+ # Format the metrics for display in the Gradio UI
155
+ lmul_html = "<div style='text-align: left;'>"
156
+ for key, value in lmul_metrics.items():
157
+ lmul_html += f"<p><b>{key}:</b> {value}</p>"
158
+ lmul_html += "</div>"
159
+
160
+ ieee_html = "<div style='text-align: left;'>"
161
+ for key, value in ieee_metrics.items():
162
+ ieee_html += f"<p><b>{key}:</b> {value}</p>"
163
+ ieee_html += "</div>"
164
+
165
+ # comparison_html = "<div style='text-align: left;'>"
166
+ # comparison_html += "<h3>Comparison (lmul vs IEEE)</h3>"
167
+ # for key, value in comparison_metrics.items():
168
+ # comparison_html += f"<p><b>{key}:</b> {value}</p>"
169
+ # comparison_html += "</div>"
170
+
171
+ return (
172
+ lmul_html,
173
+ ieee_html,
174
+ # comparison_html,
175
+ )
176
+
177
+
178
+ def predict_lmul(
179
+ sketchpad: EditorValue,
180
+ weight: str,
181
+ activation: str,
182
+ gr_progress=gr.Progress(track_tqdm=True),
183
+ ):
184
+ if weight == "float8" and activation == "float8":
185
+ activation = "bfloat16"
186
+ config = CompiledAcceleratorConfig(
187
+ array_size=8,
188
+ activation_type=dtype_map[activation],
189
+ weight_type=dtype_map[weight],
190
+ multiplier=lmul_simple,
191
+ )
192
+ sim = CompiledAcceleratorSimulator(config, MODEL)
193
+
194
+ x = image_to_tensor(sketchpad).detach().numpy().flatten()
195
+ probabilities = sim.predict(x)
196
+ return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
197
+
198
+
199
+ def predict_ieee(
200
+ sketchpad: EditorValue,
201
+ weight: str,
202
+ activation: str,
203
+ gr_progress=gr.Progress(track_tqdm=True),
204
+ ):
205
+ if weight == "float8" and activation == "float8":
206
+ activation = "bfloat16"
207
+ config = CompiledAcceleratorConfig(
208
+ array_size=8,
209
+ activation_type=dtype_map[activation],
210
+ weight_type=dtype_map[weight],
211
+ multiplier=float_multiplier,
212
+ )
213
+ simulator = CompiledAcceleratorSimulator(config, MODEL)
214
+
215
+ x = image_to_tensor(sketchpad).detach().numpy().flatten()
216
+ probabilities = simulator.predict(x)
217
+ return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
218
+
219
+
220
+ # ------------ Blocks UI Layout ------------ #
221
+
222
+
223
+ def create_app():
224
+ with gr.Blocks(fill_height=False, fill_width=False, title=__file__) as demo:
225
+
226
+ gr.Markdown("## Draw a digit to see the model's prediction")
227
+ with gr.Row(equal_height=False):
228
+ with gr.Column(scale=3):
229
+ canvas_size = (400, 400)
230
+ sketchpad = gr.Sketchpad(
231
+ # label="Draw a digit",
232
+ type="pil", # Changed to PIL
233
+ transforms=(),
234
+ layers=False,
235
+ canvas_size=canvas_size,
236
+ # scale=2,
237
+ container=False,
238
+ )
239
+ predict_btn = gr.Button(
240
+ "Run Hardware Simulation",
241
+ variant="primary",
242
+ )
243
+
244
+ # with gr.Accordion("Accelerator Configuration", open=True):
245
+ with gr.Group():
246
+ with gr.Row(): # Weight and activation types
247
+ weight_type_component = gr.Radio(
248
+ label="Weights d-type",
249
+ choices=accelerator_dtypes,
250
+ value="float8",
251
+ interactive=True,
252
+ )
253
+ activation_type_component = gr.Radio(
254
+ label="Activations d-type",
255
+ choices=accelerator_dtypes,
256
+ value="bfloat16",
257
+ interactive=True,
258
+ )
259
+ # Prevent w8a8 from being selected, or any other combination where act < weight
260
+ weight_type_component.select(
261
+ fn=filter_activation_types,
262
+ inputs=[weight_type_component, activation_type_component],
263
+ outputs=activation_type_component,
264
+ )
265
+ gr.on(
266
+ triggers=[
267
+ weight_type_component.select,
268
+ activation_type_component.select,
269
+ ],
270
+ fn=warn_w8a8,
271
+ inputs=[weight_type_component, activation_type_component],
272
+ )
273
+ with gr.Row():
274
+ systolic_array_size_component = gr.Slider(
275
+ label="Systolic Array Size",
276
+ info="Dimensions of the matrix acceleration unit",
277
+ minimum=4,
278
+ maximum=512,
279
+ step=1,
280
+ value=16,
281
+ interactive=True,
282
+ )
283
+ num_accelerator_cores_component = gr.Number(
284
+ label="Number of Accelerator Cores",
285
+ info="Total number of accelerator units per chip",
286
+ minimum=1,
287
+ maximum=1024,
288
+ step=1,
289
+ value=1,
290
+ interactive=True,
291
+ )
292
+ with gr.Row(equal_height=True):
293
+ fast_internals_component = gr.Dropdown(
294
+ label="Internal Component Type",
295
+ info="Configure the lowest level hardware units to use a faster or more efficient design.",
296
+ choices=["Fast", "Efficient"],
297
+ value="Fast",
298
+ interactive=True,
299
+ filterable=False,
300
+ )
301
+ pipeline_level_component = gr.Dropdown(
302
+ label="Pipeline Level",
303
+ info="Configure the pipeline level of processing elements within the accelerator. Low uses a single register between multipliers and adders. Full uses pipelined individual components.",
304
+ choices=["None", "Low", "Full"],
305
+ value="Full",
306
+ interactive=True,
307
+ filterable=False,
308
+ )
309
+ process_node_size_component = gr.Radio(
310
+ label="Process Node Size",
311
+ info="Configure the process node size of the hardware units. Smaller nodes are faster and use less area.",
312
+ choices=["7nm", "45nm", "130nm"],
313
+ value="45nm",
314
+ interactive=False,
315
+ )
316
+
317
+ with gr.Column(scale=2):
318
+ lmul_predictions = gr.Label(
319
+ label="l-mul Simulator Predictions",
320
+ value=labels_value,
321
+ min_width=100,
322
+ )
323
+
324
+ lmul_html = gr.HTML(
325
+ label="L-mul Hardware Stats",
326
+ show_label=True,
327
+ container=True,
328
+ )
329
+
330
+ with gr.Column(scale=2):
331
+ ieee_predictions = gr.Label(
332
+ label="IEEE Simulator Predictions",
333
+ value=labels_value,
334
+ min_width=100,
335
+ )
336
+ ieee_html = gr.HTML(
337
+ label="IEEE Multiplier Hardware Stats",
338
+ show_label=True,
339
+ container=True,
340
+ )
341
+
342
+ # ------------ Event Listeners ------------ #
343
+
344
+ predict_btn.click(
345
+ fn=predict_ieee,
346
+ inputs=[sketchpad, weight_type_component, activation_type_component],
347
+ outputs=ieee_predictions,
348
+ )
349
+
350
+ # TODO: implement simulator_predict
351
+ predict_btn.click(
352
+ fn=predict_lmul,
353
+ inputs=[sketchpad, weight_type_component, activation_type_component],
354
+ outputs=lmul_predictions,
355
+ )
356
+
357
+ gr.on(
358
+ triggers=[
359
+ demo.load,
360
+ activation_type_component.change,
361
+ weight_type_component.change,
362
+ systolic_array_size_component.change,
363
+ num_accelerator_cores_component.change,
364
+ fast_internals_component.change,
365
+ pipeline_level_component.change,
366
+ process_node_size_component.change,
367
+ ],
368
+ fn=calculate_stats,
369
+ inputs=[
370
+ activation_type_component,
371
+ weight_type_component,
372
+ systolic_array_size_component,
373
+ num_accelerator_cores_component,
374
+ fast_internals_component,
375
+ pipeline_level_component,
376
+ process_node_size_component,
377
+ ],
378
+ outputs=[lmul_html, ieee_html],
379
+ show_progress="hidden",
380
+ )
381
+
382
+ return demo
383
+
384
+
385
+ if __name__ == "__main__":
386
+ demo = create_app()
387
+ demo.queue()
388
+ demo.launch(share=False)
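filter_activation_types and warn_w8a8 are defined earlier in app.py and do not appear in this hunk. A minimal sketch of the ordering guard they enforce (the activation dtype must be at least as wide as the weight dtype), with assumed bitwidths and dtype names mirroring the radio choices above:

    import gradio as gr

    accelerator_dtypes = ["float8", "bfloat16", "float32"]  # assumed to mirror the UI choices
    bitwidths = {"float8": 8, "bfloat16": 16, "float32": 32}  # assumed widths, not taken from this diff

    def filter_activation_types(weight: str, activation: str):
        # keep only activation dtypes at least as wide as the selected weight dtype
        valid = [d for d in accelerator_dtypes if bitwidths[d] >= bitwidths[weight]]
        return gr.Radio(choices=valid, value=activation if activation in valid else valid[0])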
hardware_accelerators/compile.py ADDED
@@ -0,0 +1,167 @@
 
 
1
+ import os
2
+ import time
3
+ import multiprocessing
4
+ import pyrtl
5
+ from tqdm import tqdm
6
+ from functools import partial
7
+
8
+ from .simulation import CompiledAcceleratorSimulator
9
+
10
+ from .rtllib import float_multiplier, lmul_simple
11
+ from .rtllib.accelerator import CompiledAcceleratorConfig
12
+ from .dtypes import BaseFloat, Float32, Float16, BF16, Float8
13
+ from typing import Iterator, Type, List, Callable
14
+ from itertools import product
15
+
16
+
17
+ def generate_accelerator_configs(
18
+ array_size: int = 8,
19
+ dtypes: List[Type[BaseFloat]] | None = None,
20
+ multipliers: List[Callable] | None = None,
21
+ **kwargs,
22
+ ) -> Iterator[CompiledAcceleratorConfig]:
23
+ """
24
+ Generate all valid CompiledAcceleratorConfig combinations.
25
+
26
+ Args:
27
+ array_size: Size of the systolic array
28
+ dtypes: List of data types to consider. Defaults to [Float8, BF16, Float32]
29
+ multipliers: List of multiplier functions. Defaults to [float_multiplier, lmul_simple]
30
+
31
+ Yields:
32
+ Valid CompiledAcceleratorConfig objects
33
+
34
+ Restrictions:
35
+ 1. The activation_type must be greater than or equal to the weight_type in terms of bitwidth.
36
+ 2. 16-bit float types (BF16, FP16) should not be combined with each other.
37
+ They should only pair with themselves or with FP32.
38
+ """
39
+ if dtypes is None:
40
+ dtypes = [Float8, BF16, Float32]
41
+
42
+ if multipliers is None:
43
+ multipliers = [float_multiplier, lmul_simple]
44
+
45
+ # Sort dtypes by bitwidth for easier comparison
46
+ dtype_bitwidths = {dtype: dtype.bitwidth() for dtype in dtypes}
47
+ sorted_dtypes = sorted(dtypes, key=lambda d: dtype_bitwidths[d])
48
+
49
+ # Identify 16-bit float types
50
+ bit16_float_types = [dtype for dtype in dtypes if dtype_bitwidths[dtype] == 16]
51
+
52
+ # Generate all combinations
53
+ for multiplier in multipliers:
54
+ for weight_type in sorted_dtypes:
55
+ # Find valid activation types based on bitwidth
56
+ valid_activation_types = [
57
+ dtype
58
+ for dtype in sorted_dtypes
59
+ if dtype_bitwidths[dtype] >= dtype_bitwidths[weight_type]
60
+ ]
61
+
62
+ for activation_type in valid_activation_types:
63
+ # Skip invalid combinations of 16-bit float types
64
+ if (
65
+ weight_type in bit16_float_types
66
+ and activation_type in bit16_float_types
67
+ and weight_type != activation_type
68
+ ):
69
+ continue
70
+
71
+ yield CompiledAcceleratorConfig(
72
+ array_size=array_size,
73
+ activation_type=activation_type,
74
+ weight_type=weight_type,
75
+ multiplier=multiplier,
76
+ **kwargs,
77
+ )
78
+
79
+
80
+ def compile_and_save_simulator(config):
81
+ """Compile and save a simulator for a given configuration.
82
+
83
+ Args:
84
+ config: The CompiledAcceleratorConfig to use
85
+
86
+ Returns:
87
+ Tuple of (config, success, time_taken)
88
+ """
89
+ start_time = time.time()
90
+
91
+ try:
92
+ # Create the simulator
93
+ with pyrtl.temp_working_block():
94
+ CompiledAcceleratorSimulator(config)
95
+
96
+ end_time = time.time()
97
+ return (config, True, end_time - start_time)
98
+
99
+ except Exception as e:
100
+ end_time = time.time()
101
+ print(f"Error compiling {config}: {str(e)}")
102
+ return (config, False, end_time - start_time)
103
+
104
+
105
+ def compile_all_simulators(configs=None, max_workers=None):
106
+ """Compile and save simulators for all configurations using multiprocessing.
107
+
108
+ Args:
109
+ configs: List of configurations to compile. If None, generates all valid configs.
110
+ max_workers: Maximum number of worker processes. If None, uses CPU count.
112
+
113
+ Returns:
114
+ List of results (config, success, time_taken)
115
+ """
116
+ if configs is None:
117
+ configs = list(generate_accelerator_configs())
118
+
119
+ if max_workers is None:
120
+ max_workers = multiprocessing.cpu_count()
121
+
122
+ print(f"Compiling {len(configs)} configurations using {max_workers} workers")
123
+
124
+ # Wrap the compile function for use with the worker pool
125
+ compile_func = partial(compile_and_save_simulator)
126
+
127
+ # Use multiprocessing to compile all configurations
128
+ with multiprocessing.Pool(processes=max_workers) as pool:
129
+ # Use tqdm to show progress
130
+ results = list(
131
+ tqdm(
132
+ pool.imap(compile_func, configs),
133
+ total=len(configs),
134
+ desc="Compiling simulators",
135
+ )
136
+ )
137
+
138
+ # Print summary
139
+ successful = [r for r in results if r[1]]
140
+ failed = [r for r in results if not r[1]]
141
+
142
+ print(f"\nCompilation complete:")
143
+ print(f" Total: {len(results)}")
144
+ print(f" Successful: {len(successful)}")
145
+ print(f" Failed: {len(failed)}")
146
+
147
+ if successful:
148
+ avg_time = sum(r[2] for r in successful) / len(successful)
149
+ print(f" Average compilation time: {avg_time:.2f} seconds")
150
+
151
+ return results
152
+
153
+
154
+ if __name__ == "__main__":
155
+ # Generate all valid configurations
156
+ all_configs = list(generate_accelerator_configs())
157
+ print(f"Generated {len(all_configs)} configs")
158
+
159
+ # Compile and save simulators for all configurations
160
+ results = compile_all_simulators(all_configs)
161
+
162
+ # Print details of failed compilations if any
163
+ failed = [r for r in results if not r[1]]
164
+ if failed:
165
+ print("\nFailed compilations:")
166
+ for config, _, _ in failed:
167
+ print(config.name)
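With the defaults above (dtypes [Float8, BF16, Float32] and multipliers [float_multiplier, lmul_simple]), the generator yields 12 configurations: 6 valid (weight, activation) pairs times 2 multipliers. A quick sanity check, assuming the package is importable from the repository root:

    from hardware_accelerators.compile import generate_accelerator_configs

    configs = list(generate_accelerator_configs(array_size=8))
    # (F8,F8), (F8,BF16), (F8,F32), (BF16,BF16), (BF16,F32), (F32,F32), each with 2 multipliers
    assert len(configs) == 12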
hardware_accelerators/dtypes/__init__.py CHANGED
@@ -2,5 +2,7 @@
2
  from .base import BaseFloat
3
  from .bfloat16 import BF16
4
  from .float8 import Float8
 
 
5
 
6
- __all__ = ["BaseFloat", "Float8", "BF16"]
 
2
  from .base import BaseFloat
3
  from .bfloat16 import BF16
4
  from .float8 import Float8
5
+ from .float16 import Float16
6
+ from .float32 import Float32
7
 
8
+ __all__ = ["BaseFloat", "Float8", "BF16", "Float16", "Float32"]
hardware_accelerators/dtypes/base.py CHANGED
@@ -56,6 +56,11 @@ class BaseFloat(ABC):
56
  else:
57
  raise ValueError("Must provide one of: value, binary, or binint")
58
 
 
 
 
 
 
59
  @classmethod
60
  @abstractmethod
61
  def format_spec(cls) -> FormatSpec:
@@ -160,7 +165,6 @@ class BaseFloat(ABC):
160
 
161
  def _format_binary_string(self, binary=None) -> str:
162
  """Format binary string with dots for readability"""
163
- # Clean the input string first
164
  if binary is None:
165
  binary = self.binary
166
  clean_binary = "".join(c for c in binary if c in "01")
@@ -169,8 +173,13 @@ class BaseFloat(ABC):
169
 
170
  if self.bitwidth() == 8: # Float8
171
  return f"{clean_binary[0]}.{clean_binary[1:5]}.{clean_binary[5:]}"
172
- elif self.bitwidth() == 16: # BF16
173
- return clean_binary # BF16 doesn't use dot formatting
 
 
 
 
 
174
  else:
175
  return clean_binary
176
 
 
56
  else:
57
  raise ValueError("Must provide one of: value, binary, or binint")
58
 
59
+ @classmethod
60
+ @abstractmethod
61
+ def binary_max(cls) -> int:
62
+ pass
63
+
64
  @classmethod
65
  @abstractmethod
66
  def format_spec(cls) -> FormatSpec:
 
165
 
166
  def _format_binary_string(self, binary=None) -> str:
167
  """Format binary string with dots for readability"""
 
168
  if binary is None:
169
  binary = self.binary
170
  clean_binary = "".join(c for c in binary if c in "01")
 
173
 
174
  if self.bitwidth() == 8: # Float8
175
  return f"{clean_binary[0]}.{clean_binary[1:5]}.{clean_binary[5:]}"
176
+ elif self.bitwidth() == 32: # Float32
177
+ return f"{clean_binary[0]}.{clean_binary[1:9]}.{clean_binary[9:]}"
178
+ elif self.bitwidth() == 16:
179
+ if self.__class__.__name__ == "Float16": # Float16
180
+ return f"{clean_binary[0]}.{clean_binary[1:6]}.{clean_binary[6:]}"
181
+ else: # BF16
182
+ return clean_binary
183
  else:
184
  return clean_binary
185
 
hardware_accelerators/dtypes/bfloat16.py CHANGED
@@ -29,6 +29,10 @@ class BF16(BaseFloat):
29
  min_subnormal=2**-126 * (1 / 128),
30
  )
31
 
 
 
 
 
32
  def _float32_to_bf16_parts(self, f32: float) -> Tuple[int, int, int]:
33
  """Convert float32 to BF16 parts (sign, exponent, mantissa)"""
34
  # Get binary representation of float32
 
29
  min_subnormal=2**-126 * (1 / 128),
30
  )
31
 
32
+ @classmethod
33
+ def binary_max(cls) -> int:
34
+ return 0b0111111101111111
35
+
36
  def _float32_to_bf16_parts(self, f32: float) -> Tuple[int, int, int]:
37
  """Convert float32 to BF16 parts (sign, exponent, mantissa)"""
38
  # Get binary representation of float32
hardware_accelerators/dtypes/float16.py ADDED
@@ -0,0 +1,167 @@
 
 
1
+ from .base import BaseFloat, FormatSpec
2
+
3
+
4
+ class Float16(BaseFloat):
5
+ """
6
+ 16-bit floating point number with IEEE 754 half-precision format
7
+ - 1 sign bit
8
+ - 5 exponent bits (bias 15)
9
+ - 10 mantissa bits
10
+ """
11
+
12
+ @classmethod
13
+ def format_spec(cls) -> FormatSpec:
14
+ return FormatSpec(
15
+ total_bits=16,
16
+ exponent_bits=5,
17
+ mantissa_bits=10,
18
+ bias=15,
19
+ max_normal=65504.0, # from 0.11110.1111111111
20
+ min_normal=2**-14, # from 0.00001.0000000000
21
+ max_subnormal=2**-14 * (1023 / 1024), # from 0.00000.1111111111
22
+ min_subnormal=2**-24, # from 0.00000.0000000001
23
+ )
24
+
25
+ @classmethod
26
+ def binary_max(cls) -> int:
27
+ return 0b0111101111111111
28
+
29
+ def _decimal_to_binary(self, num: float) -> str:
30
+ """Convert decimal number to binary string in IEEE 754 format"""
31
+ if num == 0:
32
+ return "0.00000.0000000000"
33
+
34
+ # Extract sign bit
35
+ sign = "1" if num < 0 else "0"
36
+ num = abs(num)
37
+
38
+ # Handle NaN
39
+ if num != num: # Python's way to check for NaN
40
+ return sign + ".11111.1111111111"
41
+
42
+ # Clamp to max value if overflow
43
+ if num > self.max_normal():
44
+ return "0.11110.1111111111" if sign == "0" else "1.11110.1111111111"
45
+
46
+ # Find exponent and normalized mantissa
47
+ exp = 0
48
+ temp = num
49
+
50
+ # Handle normal numbers
51
+ while temp >= 2 and exp < 31:
52
+ temp /= 2
53
+ exp += 1
54
+ while temp < 1 and exp > -14:
55
+ temp *= 2
56
+ exp -= 1
57
+
58
+ # Handle subnormal numbers
59
+ if exp <= -14:
60
+ # Shift mantissa right and adjust
61
+ shift = -14 - exp
62
+ temp /= 2**shift
63
+ exp = -14
64
+
65
+ # Calculate biased exponent
66
+ if temp < 1: # Subnormal
67
+ biased_exp = "00000"
68
+ else: # Normal
69
+ biased_exp = format(exp + self.bias(), "05b")
70
+
71
+ # Calculate mantissa bits
72
+ if temp < 1: # Subnormal
73
+ mantissa_value = int(temp * (2 ** self.mantissa_bits()))
74
+ else: # Normal
75
+ mantissa_value = int((temp - 1) * (2 ** self.mantissa_bits()))
76
+
77
+ mantissa = format(mantissa_value, f"0{self.mantissa_bits()}b")
78
+
79
+ return f"{sign}.{biased_exp}.{mantissa}"
80
+
81
+ def _binary_to_decimal(self, binary: str) -> float:
82
+ """Convert binary string in IEEE 754 format to decimal"""
83
+ # Clean up binary string
84
+ binary = "".join(c for c in binary if c in "01")
85
+
86
+ # Extract components
87
+ sign = -1 if binary[0] == "1" else 1
88
+ exp = binary[1:6]
89
+ mantissa = binary[6:]
90
+
91
+ # Handle special cases
92
+ if exp == "11111" and mantissa == "1111111111": # NaN representation
93
+ return float("nan")
94
+ if exp == "00000" and mantissa == "0000000000":
95
+ return 0.0
96
+
97
+ # Convert biased exponent
98
+ biased_exp = int(exp, 2)
99
+
100
+ if biased_exp == 0: # Subnormal number
101
+ actual_exp = -14
102
+ mantissa_value = int(mantissa, 2) / (2 ** self.mantissa_bits())
103
+ return sign * (2**actual_exp) * mantissa_value
104
+ else: # Normal number
105
+ actual_exp = biased_exp - self.bias()
106
+ mantissa_value = 1 + int(mantissa, 2) / (2 ** self.mantissa_bits())
107
+ return sign * (2**actual_exp) * mantissa_value
108
+
109
+ @classmethod
110
+ def from_bits(cls, bits: int) -> "Float16":
111
+ """Create Float16 from 16-bit integer"""
112
+ return cls(binint=bits)
113
+
114
+ @classmethod
115
+ def nan(cls) -> "Float16":
116
+ """Create NaN value"""
117
+ return cls(binary="1.11111.1111111111")
118
+
119
+ @classmethod
120
+ def max_value(cls) -> "Float16":
121
+ """Create maximum representable value"""
122
+ return cls(binary="0.11110.1111111111")
123
+
124
+ @classmethod
125
+ def min_value(cls) -> "Float16":
126
+ """Create minimum representable normal value"""
127
+ return cls(binary="0.00001.0000000000")
128
+
129
+ @classmethod
130
+ def min_subnormal(cls) -> "Float16":
131
+ """Create minimum representable subnormal value"""
132
+ return cls(binary="0.00000.0000000001")
133
+
134
+ def detailed_breakdown(self) -> dict:
135
+ """Provide detailed breakdown of the Float16 number components"""
136
+ binary = "".join(c for c in self.binary if c in "01")
137
+ sign_bit = int(binary[0])
138
+ exp_bits = binary[1:6]
139
+ mantissa_bits = binary[6:]
140
+
141
+ exp_val = int(exp_bits, 2)
142
+ mantissa_val = int(mantissa_bits, 2)
143
+
144
+ is_normal = exp_val != 0 and exp_val != 31
145
+ is_subnormal = exp_val == 0 and mantissa_val != 0
146
+ is_zero = exp_val == 0 and mantissa_val == 0
147
+ is_nan = (
148
+ exp_val == 31 and mantissa_val == 1023
149
+ ) # Only s.11111.1111111111 is NaN
150
+
151
+ return {
152
+ "binary": self.binary,
153
+ "sign": sign_bit,
154
+ "exponent_bits": exp_bits,
155
+ "exponent_value": (exp_val - self.bias() if exp_val != 0 else "subnormal"),
156
+ "mantissa_bits": mantissa_bits,
157
+ "mantissa_value": mantissa_val,
158
+ "decimal_approx": self.decimal_approx,
159
+ "original_value": self.original_value,
160
+ "is_normal": is_normal,
161
+ "is_subnormal": is_subnormal,
162
+ "is_zero": is_zero,
163
+ "is_nan": is_nan,
164
+ "normalized_value": (
165
+ (1 + mantissa_val / 1024) if is_normal else (mantissa_val / 1024)
166
+ ),
167
+ }
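A worked example of the encoding above: 1.5 = 1.1 (binary) x 2^0, so sign = 0, biased exponent = 0 + 15 = 01111, and mantissa = 1000000000; binary_max() (0b0111101111111111, i.e. 0.11110.1111111111) decodes to 65504.0. A sketch, assuming the BaseFloat constructor accepts a plain value and exposes the dotted .binary string as the other dtypes do:

    from hardware_accelerators.dtypes import Float16

    x = Float16(1.5)
    print(x.binary)  # expected: 0.01111.1000000000
    print(Float16.max_value().decimal_approx)  # expected: 65504.0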
hardware_accelerators/dtypes/float32.py ADDED
@@ -0,0 +1,174 @@
 
 
1
+ from .base import BaseFloat, FormatSpec
2
+
3
+
4
+ class Float32(BaseFloat):
5
+ """
6
+ 32-bit floating point number with IEEE 754 single-precision format
7
+ - 1 sign bit
8
+ - 8 exponent bits (bias 127)
9
+ - 23 mantissa bits
10
+ """
11
+
12
+ @classmethod
13
+ def format_spec(cls) -> FormatSpec:
14
+ return FormatSpec(
15
+ total_bits=32,
16
+ exponent_bits=8,
17
+ mantissa_bits=23,
18
+ bias=127,
19
+ max_normal=3.4028235e38, # from 0.11111110.11111111111111111111111
20
+ min_normal=2**-126, # from 0.00000001.00000000000000000000000
21
+ max_subnormal=2**-126
22
+ * (8388607 / 8388608), # from 0.00000000.11111111111111111111111
23
+ min_subnormal=2**-149, # from 0.00000000.00000000000000000000001
24
+ )
25
+
26
+ @classmethod
27
+ def binary_max(cls) -> int:
28
+ return 0b01111111011111111111111111111111
29
+
30
+ def _decimal_to_binary(self, num: float) -> str:
31
+ """Convert decimal number to binary string in IEEE 754 format"""
32
+ if num == 0:
33
+ return "0.00000000.00000000000000000000000"
34
+
35
+ # Extract sign bit
36
+ sign = "1" if num < 0 else "0"
37
+ num = abs(num)
38
+
39
+ # Handle NaN
40
+ if num != num: # Python's way to check for NaN
41
+ return sign + ".11111111.11111111111111111111111"
42
+
43
+ # Clamp to max value if overflow
44
+ if num > self.max_normal():
45
+ return (
46
+ "0.11111110.11111111111111111111111"
47
+ if sign == "0"
48
+ else "1.11111110.11111111111111111111111"
49
+ )
50
+
51
+ # Find exponent and normalized mantissa
52
+ exp = 0
53
+ temp = num
54
+
55
+ # Handle normal numbers
56
+ while temp >= 2 and exp < 255:
57
+ temp /= 2
58
+ exp += 1
59
+ while temp < 1 and exp > -126:
60
+ temp *= 2
61
+ exp -= 1
62
+
63
+ # Handle subnormal numbers
64
+ if exp <= -126:
65
+ # Shift mantissa right and adjust
66
+ shift = -126 - exp
67
+ temp /= 2**shift
68
+ exp = -126
69
+
70
+ # Calculate biased exponent
71
+ if temp < 1: # Subnormal
72
+ biased_exp = "00000000"
73
+ else: # Normal
74
+ biased_exp = format(exp + self.bias(), "08b")
75
+
76
+ # Calculate mantissa bits
77
+ if temp < 1: # Subnormal
78
+ mantissa_value = int(temp * (2 ** self.mantissa_bits()))
79
+ else: # Normal
80
+ mantissa_value = int((temp - 1) * (2 ** self.mantissa_bits()))
81
+
82
+ mantissa = format(mantissa_value, f"0{self.mantissa_bits()}b")
83
+
84
+ return f"{sign}.{biased_exp}.{mantissa}"
85
+
86
+ def _binary_to_decimal(self, binary: str) -> float:
87
+ """Convert binary string in IEEE 754 format to decimal"""
88
+ # Clean up binary string
89
+ binary = "".join(c for c in binary if c in "01")
90
+
91
+ # Extract components
92
+ sign = -1 if binary[0] == "1" else 1
93
+ exp = binary[1:9]
94
+ mantissa = binary[9:]
95
+
96
+ # Handle special cases
97
+ if (
98
+ exp == "11111111" and mantissa == "11111111111111111111111"
99
+ ): # NaN representation
100
+ return float("nan")
101
+ if exp == "00000000" and mantissa == "00000000000000000000000":
102
+ return 0.0
103
+
104
+ # Convert biased exponent
105
+ biased_exp = int(exp, 2)
106
+
107
+ if biased_exp == 0: # Subnormal number
108
+ actual_exp = -126
109
+ mantissa_value = int(mantissa, 2) / (2 ** self.mantissa_bits())
110
+ return sign * (2**actual_exp) * mantissa_value
111
+ else: # Normal number
112
+ actual_exp = biased_exp - self.bias()
113
+ mantissa_value = 1 + int(mantissa, 2) / (2 ** self.mantissa_bits())
114
+ return sign * (2**actual_exp) * mantissa_value
115
+
116
+ @classmethod
117
+ def from_bits(cls, bits: int) -> "Float32":
118
+ """Create Float32 from 32-bit integer"""
119
+ return cls(binint=bits)
120
+
121
+ @classmethod
122
+ def nan(cls) -> "Float32":
123
+ """Create NaN value"""
124
+ return cls(binary="1.11111111.11111111111111111111111")
125
+
126
+ @classmethod
127
+ def max_value(cls) -> "Float32":
128
+ """Create maximum representable value"""
129
+ return cls(binary="0.11111110.11111111111111111111111")
130
+
131
+ @classmethod
132
+ def min_value(cls) -> "Float32":
133
+ """Create minimum representable normal value"""
134
+ return cls(binary="0.00000001.00000000000000000000000")
135
+
136
+ @classmethod
137
+ def min_subnormal(cls) -> "Float32":
138
+ """Create minimum representable subnormal value"""
139
+ return cls(binary="0.00000000.00000000000000000000001")
140
+
141
+ def detailed_breakdown(self) -> dict:
142
+ """Provide detailed breakdown of the Float32 number components"""
143
+ binary = "".join(c for c in self.binary if c in "01")
144
+ sign_bit = int(binary[0])
145
+ exp_bits = binary[1:9]
146
+ mantissa_bits = binary[9:]
147
+
148
+ exp_val = int(exp_bits, 2)
149
+ mantissa_val = int(mantissa_bits, 2)
150
+
151
+ is_normal = exp_val != 0 and exp_val != 255
152
+ is_subnormal = exp_val == 0 and mantissa_val != 0
153
+ is_zero = exp_val == 0 and mantissa_val == 0
154
+ is_nan = (
155
+ exp_val == 255 and mantissa_val == 8388607
156
+ ) # Only s.11111111.11111111111111111111111 is NaN
157
+
158
+ return {
159
+ "binary": self.binary,
160
+ "sign": sign_bit,
161
+ "exponent_bits": exp_bits,
162
+ "exponent_value": (exp_val - self.bias() if exp_val != 0 else "subnormal"),
163
+ "mantissa_bits": mantissa_bits,
164
+ "mantissa_value": mantissa_val,
165
+ "decimal_approx": self.decimal_approx,
166
+ "original_value": self.original_value,
167
+ "is_normal": is_normal,
168
+ "is_subnormal": is_subnormal,
169
+ "is_zero": is_zero,
170
+ "is_nan": is_nan,
171
+ "normalized_value": (
172
+ (1 + mantissa_val / 8388608) if is_normal else (mantissa_val / 8388608)
173
+ ),
174
+ }
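For reference, binary_max() above is 0x7F7FFFFF, the largest finite IEEE 754 single (sign 0, exponent 11111110, mantissa all ones), matching max_normal = 3.4028235e38. A quick cross-check against the standard library:

    import struct

    bits = 0b01111111011111111111111111111111  # 0x7F7FFFFF
    value = struct.unpack(">f", bits.to_bytes(4, "big"))[0]
    print(value)  # 3.4028234663852886e+38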
hardware_accelerators/dtypes/float8.py CHANGED
@@ -23,6 +23,10 @@ class Float8(BaseFloat):
23
  min_subnormal=2**-6 * (1 / 8), # from 0.0000.001
24
  )
25
 
 
 
 
 
26
  def _decimal_to_binary(self, num: float) -> str:
27
  """Convert decimal number to binary string in E4M3 format"""
28
  if num == 0:
 
23
  min_subnormal=2**-6 * (1 / 8), # from 0.0000.001
24
  )
25
 
26
+ @classmethod
27
+ def binary_max(cls) -> int:
28
+ return 0b01111110
29
+
30
  def _decimal_to_binary(self, num: float) -> str:
31
  """Convert decimal number to binary string in E4M3 format"""
32
  if num == 0:
hardware_accelerators/nn/lmul.py ADDED
@@ -0,0 +1,135 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchvision
5
+ from torch.utils.data import DataLoader
6
+ from torchvision import datasets, transforms
7
+ import math
8
+ import time
9
+
10
+
11
+ # Custom approximate matrix multiplication using lmul
12
+ def lmul_matmul(A: torch.Tensor, B: torch.Tensor, dtype=torch.float32):
13
+ """
14
+ Approximate matrix multiplication between A (m x n) and B (n x p)
15
+ by adding the operands' integer bit patterns (the l-mul approximation) instead of multiplying.
16
+ """
17
+ if dtype == torch.float32:
18
+ # reinterpret bits as uint32 then convert to int64 for arithmetic
19
+ A_int = A.contiguous().view(torch.uint32).to(torch.int64)
20
+ B_int = B.contiguous().view(torch.uint32).to(torch.int64)
21
+ offset = 1064828928 # l-mul offset for float32: 127 * 2**23 - 2**(23 - 4)
22
+ elif dtype == torch.bfloat16:
23
+ A_int = A.contiguous().view(torch.uint16).to(torch.int64)
24
+ B_int = B.contiguous().view(torch.uint16).to(torch.int64)
25
+ offset = 16248 # l-mul offset for bfloat16: 127 * 2**7 - 2**(7 - 4)
26
+ else:
27
+ raise ValueError("Unsupported dtype")
28
+
29
+ # A is (m, n) and B is (n, p).
30
+ # Expand dims so that:
31
+ # A_int becomes (m, n, 1) and B_int becomes (1, n, p)
32
+ prod_int = A_int.unsqueeze(2) + B_int.unsqueeze(0) - offset # shape: (m, n, p)
33
+
34
+ # Convert the integer result back to floating point.
35
+ if dtype == torch.float32:
36
+ prod = prod_int.to(torch.uint32).view(torch.float32)
37
+ else:
38
+ prod = prod_int.to(torch.uint16).view(torch.bfloat16)
39
+
40
+ # Sum over the inner dimension to complete the dot product.
41
+ return prod.sum(dim=1)
42
+
43
+
44
+ # Custom linear layer that uses lmul-based matrix multiplication
45
+ class LmulLinear(nn.Module):
46
+ def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
47
+ super(LmulLinear, self).__init__()
48
+ self.in_features = in_features
49
+ self.out_features = out_features
50
+ self.dtype = dtype
51
+ self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
52
+ if bias:
53
+ self.bias = nn.Parameter(torch.Tensor(out_features))
54
+ else:
55
+ self.register_parameter("bias", None)
56
+ self.reset_parameters()
57
+
58
+ def reset_parameters(self):
59
+ # Initialize weights similarly to nn.Linear.
60
+ nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
61
+ if self.bias is not None:
62
+ fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
63
+ bound = 1 / math.sqrt(fan_in)
64
+ nn.init.uniform_(self.bias, -bound, bound)
65
+
66
+ def forward(self, input):
67
+ # Compute the approximate matrix multiply:
68
+ # Note: input shape is (batch, in_features)
69
+ # weight.T shape is (in_features, out_features)
70
+ out = lmul_matmul(input, self.weight.t(), self.dtype)
71
+ if self.bias is not None:
72
+ out = out + self.bias # add bias as usual
73
+ return out
74
+
75
+
76
+ # MLP model using our custom lmul-based linear layers
77
+ class LmulMLP(nn.Module):
78
+ def __init__(self, input_size, hidden_size, num_classes, dtype=torch.float32):
79
+ super(LmulMLP, self).__init__()
80
+ self.flatten = nn.Flatten()
81
+ self.fc1 = LmulLinear(input_size, hidden_size, bias=True, dtype=dtype)
82
+ self.relu = nn.ReLU()
83
+ self.fc2 = LmulLinear(hidden_size, num_classes, bias=True, dtype=dtype)
84
+
85
+ def forward(self, x):
86
+ x = self.flatten(x)
87
+ x = self.fc1(x)
88
+ x = self.relu(x)
89
+ x = self.fc2(x)
90
+ return x
91
+
92
+
93
+ # Setup: use float32 for this example.
94
+ dtype = torch.float32
95
+
96
+ # Instantiate the model.
97
+ # For MNIST: input size is 28x28 = 784, hidden layer of 128, output 10 classes.
98
+ model = LmulMLP(input_size=784, hidden_size=128, num_classes=10, dtype=dtype)
99
+ model.eval() # set model to evaluation mode
100
+
101
+ model.load_state_dict(torch.load("models/mlp_mnist_fp32.pth", weights_only=True))
102
+
103
+ # Prepare the MNIST test dataset.
104
+ transform = transforms.Compose(
105
+ [
106
+ transforms.ToTensor(),
107
+ transforms.Normalize((0.1307,), (0.3081,)),
108
+ ]
109
+ )
110
+ test_dataset = datasets.MNIST(
111
+ root="./data", train=False, transform=transform, download=True
112
+ )
113
+ test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)
114
+
115
+ # Run inference on the test dataset and measure accuracy.
116
+ correct = 0
117
+ total = 0
118
+ start_time = time.time()
119
+
120
+ with torch.no_grad():
121
+ for images, labels in test_loader:
122
+ # Ensure images are in the right dtype
123
+ images = images.to(dtype)
124
+ outputs = model(images)
125
+ # Compute predictions
126
+ _, predicted = torch.max(outputs, 1)
127
+ total += labels.size(0)
128
+ correct += (predicted.cpu() == labels).sum().item()
129
+
130
+ end_time = time.time()
131
+ accuracy = correct / total * 100
132
+ inference_time = end_time - start_time
133
+
134
+ print(f"Test Accuracy: {accuracy:.2f}%")
135
+ print(f"Inference Time on Test Set: {inference_time:.2f} seconds")
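The magic offsets in lmul_matmul are consistent with the l-mul bias correction: adding the operands' integer encodings sums their biased exponents, so one bias term must be subtracted back out, along with a small correction term (2**(m - 4) for the mantissa widths used here). A quick check of the constants:

    # float32: bias 127, 23-bit mantissa
    assert 127 * 2**23 - 2**(23 - 4) == 1064828928
    # bfloat16: bias 127, 7-bit mantissa
    assert 127 * 2**7 - 2**(7 - 4) == 16248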
hardware_accelerators/nn/precision.py ADDED
@@ -0,0 +1,264 @@
 
 
1
+ import os
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.optim as optim
5
+ from torchvision import datasets, transforms
6
+ from torch.utils.data import DataLoader
7
+ from tqdm.auto import tqdm
8
+
9
+ from .util import get_pytorch_device
10
+
11
+
12
+ # Define the MLP model (unchanged)
13
+ class MLP(nn.Module):
14
+ def __init__(self, input_size, hidden_size, num_classes):
15
+ super(MLP, self).__init__()
16
+ self.flatten = nn.Flatten()
17
+ self.fc1 = nn.Linear(input_size, hidden_size)
18
+ self.relu = nn.ReLU()
19
+ self.fc2 = nn.Linear(hidden_size, num_classes)
20
+
21
+ def forward(self, x):
22
+ x = self.flatten(x)
23
+ x = self.fc1(x)
24
+ x = self.relu(x)
25
+ out = self.fc2(x)
26
+ return out
27
+
28
+
29
+ # Helper function: adjust data to match the target dtype
30
+ def convert_input(data, precision):
31
+ if precision == "fp16":
32
+ return data.half()
33
+ elif precision == "bf16":
34
+ return data.to(torch.bfloat16)
35
+ elif precision == "fp8":
36
+ # Note: torch.float8_e4m3fn is experimental and may not be available
37
+ return data.to(torch.float8_e4m3fn)
38
+ return data # fp32 (no conversion)
39
+
40
+
41
+ # Training for one epoch
42
+ def train_epoch(model, device, train_loader, optimizer, criterion, precision):
43
+ model.train()
44
+ running_loss = 0.0
45
+ progress_bar = tqdm(train_loader, desc="Training", leave=False)
46
+ for data, target in progress_bar:
47
+ # Convert inputs to the desired precision (targets remain integer)
48
+ data = convert_input(data, precision)
49
+ data, target = data.to(device), target.to(device)
50
+ optimizer.zero_grad()
51
+
52
+ # Forward pass
53
+ outputs = model(data)
54
+ loss = criterion(outputs, target)
55
+
56
+ # Check for NaN and skip problematic batches
57
+ if torch.isnan(loss):
58
+ print("NaN loss detected in batch, skipping...")
59
+ continue
60
+
61
+ # Backward and optimize with gradient clipping
62
+ loss.backward()
63
+
64
+ # Apply gradient clipping to prevent exploding gradients
65
+ torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
66
+
67
+ optimizer.step()
68
+ running_loss += loss.item()
69
+
70
+ if len(train_loader) > 0:
71
+ return running_loss / len(train_loader)
72
+ return 0.0
73
+
74
+
75
+ # Evaluation loop
76
+ def evaluate(model, device, test_loader, criterion, precision):
77
+ model.eval()
78
+ total_loss = 0.0
79
+ correct = 0
80
+ total = 0
81
+ with torch.no_grad():
82
+ for data, target in test_loader:
83
+ data = convert_input(data, precision)
84
+ data, target = data.to(device), target.to(device)
85
+ outputs = model(data)
86
+ loss = criterion(outputs, target)
87
+ total_loss += loss.item()
88
+ _, predicted = torch.max(outputs, 1)
89
+ total += target.size(0)
90
+ correct += (predicted == target).sum().item()
91
+ avg_loss = total_loss / len(test_loader)
92
+ accuracy = 100.0 * correct / total
93
+ return avg_loss, accuracy
94
+
95
+
96
+ # Main training function for a given precision variant
97
+ def train_model(
98
+ precision,
99
+ batch_size=32,
100
+ hidden_size=128,
101
+ num_epochs=5,
102
+ learning_rate=0.001,
103
+ optimizer_name="adam",
104
+ weight_decay=0,
105
+ eps=1e-4,
106
+ model_save_path=None,
107
+ ):
108
+ print(f"\nTraining in {precision.upper()} mode:")
109
+ device = get_pytorch_device()
110
+
111
+ # Data transformation: images are loaded as FP32 by default
112
+ transform = transforms.Compose(
113
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
114
+ )
115
+
116
+ train_dataset = datasets.MNIST(
117
+ root="./data", train=True, download=True, transform=transform
118
+ )
119
+ test_dataset = datasets.MNIST(
120
+ root="./data", train=False, download=True, transform=transform
121
+ )
122
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
123
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
124
+
125
+ # Hyperparameters
126
+ input_size = 28 * 28 # MNIST images are 28x28
127
+ num_classes = 10
128
+
129
+ # Create the model and send to device
130
+ model = MLP(input_size, hidden_size, num_classes).to(device)
131
+
132
+ # Convert the model to the target precision (natively)
133
+ if precision == "fp16":
134
+ model = model.to(torch.float16)
135
+ # Use a smaller learning rate for half precision if not explicitly specified
136
+ if learning_rate == 0.001: # If using the default value
137
+ learning_rate = 1e-4 # Lower learning rate for stability
138
+ elif precision == "bf16":
139
+ model = model.to(torch.bfloat16)
140
+ elif precision == "fp8":
141
+ # Ensure your PyTorch build/hardware supports float8_e4m3fn; otherwise, this will error.
142
+ model = model.to(torch.float8_e4m3fn)
143
+ # else, fp32 is already the default
144
+
145
+ # Select optimizer based on user input
146
+ if optimizer_name.lower() == "adam":
147
+ optimizer = optim.Adam(
148
+ model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay
149
+ )
150
+ elif optimizer_name.lower() == "sgd":
151
+ optimizer = optim.SGD(
152
+ model.parameters(),
153
+ lr=learning_rate,
154
+ momentum=0.9,
155
+ weight_decay=weight_decay,
156
+ )
157
+ elif optimizer_name.lower() == "adamw":
158
+ optimizer = optim.AdamW(
159
+ model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay
160
+ )
161
+ else:
162
+ print(f"Unknown optimizer: {optimizer_name}, defaulting to Adam")
163
+ optimizer = optim.Adam(
164
+ model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay
165
+ )
166
+
167
+ criterion = nn.CrossEntropyLoss()
168
+
169
+ print(
170
+ f"Training with: batch_size={batch_size}, hidden_size={hidden_size}, "
171
+ f"epochs={num_epochs}, lr={learning_rate}, optimizer={optimizer_name}"
172
+ )
173
+
174
+ # Training loop
175
+ for epoch in range(1, num_epochs + 1):
176
+ train_loss = train_epoch(
177
+ model, device, train_loader, optimizer, criterion, precision
178
+ )
179
+
180
+ # Check for NaN loss
181
+ if torch.isnan(torch.tensor([train_loss])):
182
+ print(f"NaN detected in epoch {epoch}, reducing learning rate")
183
+ for param_group in optimizer.param_groups:
184
+ param_group["lr"] *= 0.5
185
+
186
+ print(f"Epoch {epoch} Train Loss: {train_loss:.4f}")
187
+
188
+ # Evaluation on test set
189
+ test_loss, test_accuracy = evaluate(
190
+ model, device, test_loader, criterion, precision
191
+ )
192
+ print(
193
+ f"{precision.upper()} Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
194
+ )
195
+
196
+ # Optionally, save the model
197
+ if model_save_path:
198
+ save_path = model_save_path
199
+ else:
200
+ model_dir = "models"
201
+ os.makedirs(model_dir, exist_ok=True)
202
+ save_path = os.path.join(model_dir, f"mlp_mnist_{precision}.pth")
203
+
204
+ torch.save(model.state_dict(), save_path)
205
+ print(f"Model saved to {save_path}\n")
206
+
207
+
208
+ # Main script to train a model in a specific precision
209
+ if __name__ == "__main__":
210
+ import argparse
211
+
212
+ parser = argparse.ArgumentParser(
213
+ description="Train MNIST model in a specific precision"
214
+ )
215
+ parser.add_argument(
216
+ "--dtype",
217
+ type=str,
218
+ default="fp32",
219
+ choices=["fp32", "fp16", "bf16", "fp8"],
220
+ help="Precision type to train in (fp32, fp16, bf16, fp8)",
221
+ )
222
+ parser.add_argument(
223
+ "--batch-size", type=int, default=32, help="Batch size for training"
224
+ )
225
+ parser.add_argument(
226
+ "--hidden-size", type=int, default=128, help="Hidden layer size for MLP"
227
+ )
228
+ parser.add_argument(
229
+ "--epochs", type=int, default=5, help="Number of training epochs"
230
+ )
231
+ parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
232
+ parser.add_argument(
233
+ "--optimizer",
234
+ type=str,
235
+ default="adam",
236
+ choices=["adam", "sgd", "adamw"],
237
+ help="Optimizer to use for training",
238
+ )
239
+ parser.add_argument(
240
+ "--weight-decay", type=float, default=0, help="Weight decay (L2 penalty)"
241
+ )
242
+ parser.add_argument(
243
+ "--eps", type=float, default=1e-4, help="Epsilon for Adam optimizer"
244
+ )
245
+ parser.add_argument(
246
+ "--save-path", type=str, default=None, help="Path to save the trained model"
247
+ )
248
+
249
+ args = parser.parse_args()
250
+
251
+ try:
252
+ train_model(
253
+ precision=args.dtype,
254
+ batch_size=args.batch_size,
255
+ hidden_size=args.hidden_size,
256
+ num_epochs=args.epochs,
257
+ learning_rate=args.lr,
258
+ optimizer_name=args.optimizer,
259
+ weight_decay=args.weight_decay,
260
+ eps=args.eps,
261
+ model_save_path=args.save_path,
262
+ )
263
+ except Exception as e:
264
+ print(f"Error training {args.dtype.upper()} model: {e}")
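train_model() can be driven from the argparse entry point above or called directly; a sketch of the latter, assuming it runs from the repository root so the relative ./data and models/ paths resolve:

    from hardware_accelerators.nn.precision import train_model

    # trains the MLP on MNIST in bfloat16 and saves models/mlp_mnist_bf16.pth
    train_model(precision="bf16", batch_size=64, num_epochs=5, optimizer_name="adamw")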
hardware_accelerators/nn/precision_eval.py ADDED
@@ -0,0 +1,280 @@
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torchvision import datasets, transforms
4
+ from torch.utils.data import DataLoader
5
+ import numpy as np
6
+ import time
7
+ from pathlib import Path
8
+
9
+ from .precision import MLP
10
+ from .util import get_pytorch_device
11
+
12
+
13
+ def load_mnist_data(batch_size=100):
14
+ """Load MNIST test dataset"""
15
+ transform = transforms.Compose(
16
+ [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
17
+ )
18
+ test_dataset = datasets.MNIST(
19
+ root="./data", train=False, download=True, transform=transform
20
+ )
21
+ test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
22
+ return test_loader
23
+
24
+
25
+ def create_model(precision):
26
+ """Create MLP model with specified precision"""
27
+ input_size = 28 * 28 # MNIST images are 28x28
28
+ hidden_size = 128
29
+ num_classes = 10
30
+ device = get_pytorch_device()
31
+
32
+ model = MLP(input_size, hidden_size, num_classes).to(device)
33
+
34
+ # Convert model to target precision
35
+ if precision == "fp16":
36
+ model = model.to(torch.float16)
37
+ elif precision == "bf16":
38
+ model = model.to(torch.bfloat16)
39
+ elif precision == "fp32":
40
+ model = model.to(torch.float32)
41
+
42
+ return model, device
43
+
44
+
45
+ def load_model_weights(model, model_path):
46
+ """Load model weights from checkpoint"""
47
+ model.load_state_dict(torch.load(model_path))
48
+ return model
49
+
50
+
51
+ def evaluate_model(model, test_loader, device, precision):
52
+ """Evaluate model accuracy and inference time"""
53
+ model.eval()
54
+ correct = 0
55
+ total = 0
56
+
57
+ # For measuring inference time
58
+ start_time = time.time()
59
+
60
+ with torch.no_grad():
61
+ for data, target in test_loader:
62
+ # Convert input to specified precision
63
+ if precision == "fp16":
64
+ data = data.half()
65
+ elif precision == "bf16":
66
+ data = data.to(torch.bfloat16)
67
+
68
+ data, target = data.to(device), target.to(device)
69
+
70
+ # Forward pass
71
+ outputs = model(data)
72
+
73
+ # Calculate accuracy
74
+ _, predicted = torch.max(outputs, 1)
75
+ total += target.size(0)
76
+ correct += (predicted == target).sum().item()
77
+
78
+ inference_time = time.time() - start_time
79
+ accuracy = 100.0 * correct / total
80
+
81
+ return {
82
+ "accuracy": accuracy,
83
+ "inference_time": inference_time,
84
+ "correct": correct,
85
+ "total": total,
86
+ }
87
+
88
+
89
+ def compare_precision_inference(fp32_model_path):
90
+ """Compare FP32 trained model inference in different precisions"""
91
+ print("\nEvaluating FP32-trained model inference with different precisions")
92
+ print("-" * 80)
93
+
94
+ # Load test data
95
+ test_loader = load_mnist_data()
96
+
97
+ # Verify the model file exists
98
+ model_path = Path(fp32_model_path)
99
+ if not model_path.exists():
100
+ print(f"Error: Model file {fp32_model_path} not found!")
101
+ return
102
+
103
+ # Create models in different precisions
104
+ precisions = ["fp32", "bf16"]
105
+ results = {}
106
+
107
+ for precision in precisions:
108
+ print(f"\nTesting inference in {precision.upper()} mode...")
109
+
110
+ # Create model in specified precision
111
+ model, device = create_model(precision)
112
+
113
+ # Load weights from the FP32-trained model
114
+ # When loading to a BF16 model, the weights will be automatically cast
115
+ model = load_model_weights(model, fp32_model_path)
116
+
117
+ # Evaluate model
118
+ results[precision] = evaluate_model(model, test_loader, device, precision)
119
+
120
+ # Print results
121
+ print(f"Accuracy: {results[precision]['accuracy']:.2f}%")
122
+ print(f"Inference time: {results[precision]['inference_time']:.4f} seconds")
123
+
124
+ # Calculate and print comparison metrics
125
+ print("\nComparison Summary")
126
+ print("-" * 80)
127
+
128
+ fp32_results = results["fp32"]
129
+ bf16_results = results["bf16"]
130
+
131
+ acc_diff = bf16_results["accuracy"] - fp32_results["accuracy"]
132
+ time_ratio = fp32_results["inference_time"] / bf16_results["inference_time"]
133
+
134
+ print(f"Accuracy drop from FP32 to BF16: {acc_diff:.2f}%")
135
+ print(f"Inference speedup with BF16: {time_ratio:.2f}x")
136
+
137
+ return results
138
+
139
+
140
+ def detailed_precision_comparison(
141
+ fp32_model_path, trials=3, batch_sizes=[1, 16, 32, 64, 128, 256]
142
+ ):
143
+ """Run detailed comparison with multiple batch sizes and trials"""
144
+ print("\nDetailed Precision Comparison")
145
+ print("=" * 80)
146
+
147
+ # Verify the model file exists
148
+ model_path = Path(fp32_model_path)
149
+ if not model_path.exists():
150
+ print(f"Error: Model file {fp32_model_path} not found!")
151
+ return
152
+
153
+ precisions = ["fp32", "bf16"]
154
+ all_results = {}
155
+
156
+ for precision in precisions:
157
+ all_results[precision] = {}
158
+
159
+ print(f"\nEvaluating {precision.upper()} precision")
160
+ print("-" * 60)
161
+
162
+ # Create model in specified precision only once
163
+ model, device = create_model(precision)
164
+ model = load_model_weights(model, fp32_model_path)
165
+
166
+ # Warm up the GPU/CPU
167
+ print("Warming up...")
168
+ dummy_loader = load_mnist_data(batch_size=64)
169
+ with torch.no_grad():
170
+ for data, _ in dummy_loader:
171
+ if precision == "bf16":
172
+ data = data.to(torch.bfloat16)
173
+ elif precision == "fp16":
174
+ data = data.half()
175
+ data = data.to(device)
176
+ _ = model(data)
177
+ break
178
+
179
+ # Run trials for different batch sizes
180
+ for batch_size in batch_sizes:
181
+ print(f"\n Batch size: {batch_size}")
182
+ test_loader = load_mnist_data(batch_size=batch_size)
183
+
184
+ batch_results = {"accuracy": [], "inference_time": []}
185
+
186
+ for trial in range(trials):
187
+ print(f" Trial {trial+1}/{trials}...", end="", flush=True)
188
+ result = evaluate_model(model, test_loader, device, precision)
189
+ batch_results["accuracy"].append(result["accuracy"])
190
+ batch_results["inference_time"].append(result["inference_time"])
191
+ print(
192
+ f" done. Time: {result['inference_time']:.4f}s, Accuracy: {result['accuracy']:.2f}%"
193
+ )
194
+
195
+ # Calculate averages
196
+ avg_accuracy = sum(batch_results["accuracy"]) / len(
197
+ batch_results["accuracy"]
198
+ )
199
+ avg_time = sum(batch_results["inference_time"]) / len(
200
+ batch_results["inference_time"]
201
+ )
202
+
203
+ all_results[precision][batch_size] = {
204
+ "avg_accuracy": avg_accuracy,
205
+ "avg_inference_time": avg_time,
206
+ "trials": batch_results,
207
+ }
208
+
209
+ print(f" Average: {avg_time:.4f}s, Accuracy: {avg_accuracy:.2f}%")
210
+
211
+ # Print comparison table
212
+ print("\nComparison Results")
213
+ print("=" * 80)
214
+ print(
215
+ f"{'Batch Size':^10} | {'FP32 Acc':^10} | {'BF16 Acc':^10} | {'Acc Diff':^10} | {'FP32 Time':^10} | {'BF16 Time':^10} | {'Speedup':^10}"
216
+ )
217
+ print("-" * 80)
218
+
219
+ for batch_size in batch_sizes:
220
+ fp32_acc = all_results["fp32"][batch_size]["avg_accuracy"]
221
+ bf16_acc = all_results["bf16"][batch_size]["avg_accuracy"]
222
+ acc_diff = bf16_acc - fp32_acc
223
+
224
+ fp32_time = all_results["fp32"][batch_size]["avg_inference_time"]
225
+ bf16_time = all_results["bf16"][batch_size]["avg_inference_time"]
226
+ speedup = fp32_time / bf16_time
227
+
228
+ print(
229
+ f"{batch_size:^10} | {fp32_acc:^10.2f} | {bf16_acc:^10.2f} | {acc_diff:^10.2f} | {fp32_time:^10.4f} | {bf16_time:^10.4f} | {speedup:^10.2f}x"
230
+ )
231
+
232
+ # Calculate and print overall averages
233
+ avg_acc_diff = sum(
234
+ all_results["bf16"][bs]["avg_accuracy"]
235
+ - all_results["fp32"][bs]["avg_accuracy"]
236
+ for bs in batch_sizes
237
+ ) / len(batch_sizes)
238
+ avg_speedup = sum(
239
+ all_results["fp32"][bs]["avg_inference_time"]
240
+ / all_results["bf16"][bs]["avg_inference_time"]
241
+ for bs in batch_sizes
242
+ ) / len(batch_sizes)
243
+
244
+ print("-" * 80)
245
+ print(f"Average accuracy difference across all batch sizes: {avg_acc_diff:.2f}%")
246
+ print(f"Average speedup across all batch sizes: {avg_speedup:.2f}x")
247
+
248
+ return all_results
249
+
250
+
251
+ if __name__ == "__main__":
252
+ import argparse
253
+
254
+ parser = argparse.ArgumentParser(
255
+ description="Compare model inference with FP32 vs BF16"
256
+ )
257
+ parser.add_argument(
258
+ "--model_path",
259
+ type=str,
260
+ default="models/mlp_mnist_fp32.pth",
261
+ help="Path to FP32 trained model weights",
262
+ )
263
+ parser.add_argument(
264
+ "--detailed",
265
+ action="store_true",
266
+ help="Run detailed comparison with multiple batch sizes",
267
+ )
268
+ parser.add_argument(
269
+ "--trials",
270
+ type=int,
271
+ default=3,
272
+ help="Number of trials to run (only with --detailed)",
273
+ )
274
+
275
+ args = parser.parse_args()
276
+
277
+ if args.detailed:
278
+ detailed_precision_comparison(args.model_path, trials=args.trials)
279
+ else:
280
+ compare_precision_inference(args.model_path)
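A sketch of calling the detailed comparison directly, assuming an FP32 checkpoint already exists (for example one produced by nn/precision.py):

    from hardware_accelerators.nn.precision_eval import detailed_precision_comparison

    results = detailed_precision_comparison(
        "models/mlp_mnist_fp32.pth", trials=3, batch_sizes=[1, 64, 256]
    )
    print(results["bf16"][64]["avg_accuracy"])  # average BF16 accuracy at batch size 64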
hardware_accelerators/nn/run_precision_comparison.py ADDED
@@ -0,0 +1,78 @@
 
 
1
+ #!/usr/bin/env python3
2
+
3
+ import os
4
+ import argparse
5
+ from hardware_accelerators.nn.precision import train_model
6
+ from hardware_accelerators.nn.precision_eval import (
7
+ compare_precision_inference,
8
+ detailed_precision_comparison,
9
+ )
10
+
11
+
12
+ def main():
13
+ # Parse command line arguments
14
+ parser = argparse.ArgumentParser(
15
+ description="Train and evaluate precision differences"
16
+ )
17
+ parser.add_argument(
18
+ "--detailed",
19
+ action="store_true",
20
+ help="Run detailed comparison with multiple batch sizes and trials",
21
+ )
22
+ parser.add_argument(
23
+ "--trials",
24
+ type=int,
25
+ default=3,
26
+ help="Number of trials to run for each batch size (only with --detailed)",
27
+ )
28
+ parser.add_argument(
29
+ "--force-train",
30
+ action="store_true",
31
+ help="Force training a new model even if one exists",
32
+ )
33
+ parser.add_argument(
34
+ "--batch-sizes",
35
+ nargs="+",
36
+ type=int,
37
+ default=[1, 16, 32, 64, 128, 256],
38
+ help="Batch sizes to test (only with --detailed)",
39
+ )
40
+ args = parser.parse_args()
41
+
42
+ # Create models directory if it doesn't exist
43
+ os.makedirs("models", exist_ok=True)
44
+
45
+ # Path to save/load model
46
+ model_path = "models/mlp_mnist_fp32.pth"
47
+
48
+ # Check if model exists, train if not or if forced
49
+ if not os.path.exists(model_path) or args.force_train:
50
+ print("Training a new FP32 model...")
51
+ # Train a model in FP32 precision
52
+ train_model(
53
+ precision="fp32",
54
+ batch_size=64,
55
+ hidden_size=128,
56
+ num_epochs=5,
57
+ learning_rate=0.001,
58
+ optimizer_name="adam",
59
+ model_save_path=model_path,
60
+ )
61
+ else:
62
+ print(f"Using existing model at {model_path}")
63
+
64
+ # Run comparison
65
+ if args.detailed:
66
+ print(
67
+ f"Running detailed comparison with {args.trials} trials for each batch size..."
68
+ )
69
+ detailed_precision_comparison(
70
+ model_path, trials=args.trials, batch_sizes=args.batch_sizes
71
+ )
72
+ else:
73
+ # Run basic comparison
74
+ compare_precision_inference(model_path)
75
+
76
+
77
+ if __name__ == "__main__":
78
+ main()
hardware_accelerators/nn/train.py CHANGED
@@ -8,8 +8,6 @@ from tqdm.auto import tqdm
8
 
9
  from .util import model_factory, get_pytorch_device # progress bar for notebooks
10
 
11
- # from pytorch2tikz import Architecture
12
-
13
 
14
  # Training function for one epoch
15
  def train(model, device, train_loader, optimizer, criterion, epoch, num_epochs):
 
8
 
9
  from .util import model_factory, get_pytorch_device # progress bar for notebooks
10
 
 
 
11
 
12
  # Training function for one epoch
13
  def train(model, device, train_loader, optimizer, criterion, epoch, num_epochs):
hardware_accelerators/nn/util.py CHANGED
@@ -27,8 +27,10 @@ def load_model(model_path: str, device: torch.device | None = None):
27
  if device is None:
28
  device = get_pytorch_device()
29
  model = model_factory()
 
 
 
30
  model.to(device)
31
- model.load_state_dict(torch.load(model_path, map_location=device))
32
  return model
33
 
34
 
 
27
  if device is None:
28
  device = get_pytorch_device()
29
  model = model_factory()
30
+ model.load_state_dict(
31
+ torch.load(model_path, map_location=device, weights_only=True)
32
+ )
33
  model.to(device)
 
34
  return model
35
 
36
 
hardware_accelerators/rtllib/__init__.py CHANGED
@@ -9,9 +9,13 @@ from .accelerator import (
9
  TiledMatrixEngine,
10
  AcceleratorConfig,
11
  Accelerator,
 
 
 
 
12
  )
13
 
14
- all = [
15
  "float_adder",
16
  "FloatAdderPipelined",
17
  "float_multiplier",
@@ -20,11 +24,15 @@ all = [
20
  "lmul_fast",
21
  "LmulPipelined",
22
  "SystolicArrayDiP",
23
- "AccumulatorMemoryBank",
24
  "BufferMemory",
25
  "WeightFIFO",
26
  "TiledAcceleratorConfig",
27
  "TiledMatrixEngine",
28
  "AcceleratorConfig",
29
  "Accelerator",
 
 
 
 
30
  ]
 
9
  TiledMatrixEngine,
10
  AcceleratorConfig,
11
  Accelerator,
12
+ CompiledAcceleratorConfig,
13
+ CompiledAccelerator,
14
+ AcceleratorAnalysisConfig,
15
+ AcceleratorTopLevel,
16
  )
17
 
18
+ __all__ = [
19
  "float_adder",
20
  "FloatAdderPipelined",
21
  "float_multiplier",
 
24
  "lmul_fast",
25
  "LmulPipelined",
26
  "SystolicArrayDiP",
27
+ "TiledAccumulatorMemoryBank",
28
  "BufferMemory",
29
  "WeightFIFO",
30
  "TiledAcceleratorConfig",
31
  "TiledMatrixEngine",
32
  "AcceleratorConfig",
33
  "Accelerator",
34
+ "CompiledAcceleratorConfig",
35
+ "CompiledAccelerator",
36
+ "AcceleratorAnalysisConfig",
37
+ "AcceleratorTopLevel",
38
  ]
hardware_accelerators/rtllib/accelerator.py CHANGED
@@ -1,99 +1,140 @@
 
1
  from dataclasses import dataclass
2
- from typing import Callable, Type, Dict
3
  import numpy as np
 
 
4
 
5
  from pyrtl import (
6
  WireVector,
7
  Register,
 
8
  Output,
9
  Simulation,
 
10
  concat,
11
  )
12
 
 
13
 
 
 
 
 
 
14
  from .buffer import BufferMemory, WeightFIFO
15
  from .systolic import SystolicArrayDiP
16
  from .accumulators import Accumulator, TiledAccumulatorMemoryBank
17
- from .activations import ReluUnit
18
  from ..dtypes import BaseFloat
19
 
 
20
 
21
- @dataclass
22
- class AcceleratorConfig:
23
- """Configuration class for a systolic array accelerator.
24
 
25
- This class defines the parameters and specifications for a systolic array
26
- accelerator including array dimensions, data types, arithmetic operations,
27
- and memory configuration.
28
- """
29
 
30
  array_size: int
31
- """Dimension of systolic array (always square)"""
32
-
33
- num_weight_tiles: int
34
- """Number of weight tiles in the FIFO. Each tile is equal to the size of the systolic array"""
35
-
36
- data_type: Type[BaseFloat]
37
- """Floating point format of input data to systolic array"""
38
-
39
  weight_type: Type[BaseFloat]
40
- """Floating point format of weight inputs"""
 
 
 
 
 
 
 
 
 
 
 
41
 
42
- accum_type: Type[BaseFloat]
43
- """Floating point format to accumulate values in"""
 
 
 
 
 
 
 
 
 
 
 
44
 
45
- pe_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
46
- """Function to generate adder hardware for the processing elements"""
 
 
 
 
 
 
 
 
 
 
47
 
48
- accum_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
49
- """Function to generate adder hardware for the accumulator buffer"""
50
 
51
- pe_multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
52
- """Function to generate multiplier hardware for the processing elements"""
 
 
 
 
 
 
 
 
 
 
53
 
54
- pipeline: bool
55
- """Whether to add a pipeline stage in processing elements between multiplier and adder"""
 
 
56
 
57
- accum_addr_width: int
58
- """Address width for accumulator memory. Determines number of individually addressable locations"""
59
 
60
  @property
61
- def weight_tile_addr_width(self):
62
- """Get the width of the weight tile address bus in bits"""
63
- return (self.num_weight_tiles - 1).bit_length()
64
 
65
 
66
- class Accelerator:
67
- def __init__(self, config: AcceleratorConfig):
68
  self.config = config
69
 
70
  # Instantiate hardware components
71
- self.fifo = WeightFIFO(
72
- array_size=config.array_size,
73
- num_tiles=config.num_weight_tiles,
74
- dtype=config.weight_type,
75
- )
76
  self.systolic_array = SystolicArrayDiP(
77
  size=config.array_size,
78
- data_type=config.data_type,
79
  weight_type=config.weight_type,
80
- accum_type=config.accum_type,
81
- multiplier=config.pe_multiplier,
82
- adder=config.pe_adder,
83
- pipeline=config.pipeline,
84
  )
85
  self.accumulator = Accumulator(
86
- addr_width=config.accum_addr_width,
87
  array_size=config.array_size,
88
- data_type=config.accum_type,
89
- adder=config.accum_adder,
90
  )
91
  self.activation = ReluUnit(
92
  size=config.array_size,
93
- dtype=config.accum_type,
94
  )
95
  self.outputs = [
96
- WireVector(config.accum_type.bitwidth()) for _ in range(config.array_size)
 
97
  ]
98
 
99
  # Connect components
@@ -103,12 +144,15 @@ class Accelerator:
103
  """Create unnamed WireVectors for control signals"""
104
  self.data_enable = WireVector(1)
105
  self.data_ins = [
106
- WireVector(self.config.data_type.bitwidth())
107
  for _ in range(self.config.array_size)
108
  ]
109
 
110
- self.weight_start_in = WireVector(1)
111
- self.weight_tile_addr_in = WireVector(self.fifo.tile_addr_width)
 
 
 
112
 
113
  self.accum_addr_in = WireVector(self.config.accum_addr_width)
114
  self.accum_mode_in = WireVector(1)
@@ -117,7 +161,7 @@ class Accelerator:
117
  self.act_func_in = WireVector(1) # Apply activation function or passthrough
118
 
119
  def _create_pipeline_registers(self):
120
- num_registers = self.config.array_size + 1 + int(self.config.pipeline)
121
 
122
  self.accum_addr_regs = [
123
  Register(self.config.accum_addr_width) for _ in range(num_registers)
@@ -153,16 +197,11 @@ class Accelerator:
153
  self._create_pipeline_registers()
154
 
155
  # Connect buffer to external inputs
156
- self.fifo.connect_inputs(
157
- start=self.weight_start_in,
158
- tile_addr=self.weight_tile_addr_in,
159
- )
160
-
161
  self.systolic_array.connect_inputs(
162
  data_inputs=self.data_ins,
163
  enable_input=self.data_enable,
164
- weight_inputs=self.fifo.outputs.weights,
165
- weight_enable=self.fifo.outputs.active,
166
  )
167
 
168
  # Connect accumulator to systolic array
@@ -188,8 +227,8 @@ class Accelerator:
188
  self,
189
  data_enable: WireVector | None = None,
190
  data_inputs: list[WireVector] | None = None,
191
- weight_start: WireVector | None = None,
192
- weight_tile_addr: WireVector | None = None,
193
  accum_addr: WireVector | None = None,
194
  accum_mode: WireVector | None = None,
195
  act_start: WireVector | None = None,
@@ -204,9 +243,8 @@ class Accelerator:
204
  Args:
205
  data_enable: 1-bit signal that enables data flow into the systolic array
206
  data_inputs: List of input data wires for the systolic array. Must match array_size
207
- weight_start: 1-bit signal that triggers loading of a new weight tile when pulsed high
208
- weight_tile_addr: Address selecting which weight tile to load from the FIFO.
209
- Width must match the FIFO's tile address width
210
  accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
211
  accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
212
  act_start: 1-bit signal to enable passing data through the activation unit
@@ -226,22 +264,27 @@ class Accelerator:
226
  f"Expected {self.config.array_size}, got {len(data_inputs)}"
227
  )
228
  for i, wire in enumerate(data_inputs):
229
- assert len(wire) == self.config.data_type.bitwidth(), (
230
  f"Data input width mismatch. "
231
- f"Expected {self.config.data_type.bitwidth()}, got {len(wire)}"
232
  )
233
  self.data_ins[i] <<= wire
234
 
235
- if weight_start is not None:
236
- assert len(weight_start) == 1, "Weight start signal must be 1 bit wide"
237
- self.weight_start_in <<= weight_start
238
 
239
- if weight_tile_addr is not None:
240
- assert len(weight_tile_addr) == self.fifo.tile_addr_width, (
241
- f"Weight tile address width mismatch. "
242
- f"Expected {self.fifo.tile_addr_width}, got {len(weight_tile_addr)}"
243
  )
244
- self.weight_tile_addr_in <<= weight_tile_addr
245
 
246
  if accum_addr is not None:
247
  assert len(accum_addr) == self.config.accum_addr_width, (
@@ -278,12 +321,8 @@ class Accelerator:
278
  assert len(valid) == 1, "Output valid signal must be a single bit wire"
279
  valid <<= self.activation.outputs_valid
280
 
281
- def inspect_systolic_array_state(self, sim: Simulation):
282
- """Return current PE array state"""
283
- return self.systolic_array.get_state(sim)
284
-
285
- def inspect_accumulator_state(self, sim: Simulation) -> np.ndarray:
286
- """Return all accumulator tiles as 3D array.
287
 
288
  Args:
289
  sim: PyRTL simulation instance
@@ -296,18 +335,237 @@ class Accelerator:
296
  tiles = []
297
  for addr in range(2**self.config.accum_addr_width):
298
  row = [
299
- float(self.config.accum_type(binint=sim.inspect_mem(bank).get(addr, 0)))
300
  for bank in self.accumulator.memory_banks
301
  ]
302
  tiles.append(row)
303
  return np.array(tiles)
304
 
305
 
306
- class CompiledAccelerator:
307
  def __init__(self, config: AcceleratorConfig):
308
  self.config = config
309
 
310
  # Instantiate hardware components
311
  self.systolic_array = SystolicArrayDiP(
312
  size=config.array_size,
313
  data_type=config.data_type,
@@ -342,11 +600,8 @@ class CompiledAccelerator:
342
  for _ in range(self.config.array_size)
343
  ]
344
 
345
- self.weight_enable = WireVector(1)
346
- self.weights_in = [
347
- WireVector(self.config.weight_type.bitwidth())
348
- for _ in range(self.config.array_size)
349
- ]
350
 
351
  self.accum_addr_in = WireVector(self.config.accum_addr_width)
352
  self.accum_mode_in = WireVector(1)
@@ -367,23 +622,33 @@ class CompiledAccelerator:
367
  self.accum_mode_out = WireVector(1)
368
  self.accum_mode_out <<= self.accum_mode_regs[-1]
369
 
370
- self.act_control_regs = [Register(2) for _ in range(num_registers)]
371
- self.act_control_regs[0].next <<= concat(self.act_start_in, self.act_func_in)
372
 
373
  self.accum_addr_regs[0].next <<= self.accum_addr_in
374
  self.accum_mode_regs[0].next <<= self.accum_mode_in
375
  for i in range(1, len(self.accum_addr_regs)):
376
  self.accum_addr_regs[i].next <<= self.accum_addr_regs[i - 1]
377
  self.accum_mode_regs[i].next <<= self.accum_mode_regs[i - 1]
378
- self.act_control_regs[i].next <<= self.act_control_regs[i - 1]
 
 
 
379
 
380
  self.act_addr = Register(self.config.accum_addr_width)
381
  self.act_func = Register(1)
382
  self.act_start = Register(1)
383
 
384
  self.act_addr.next <<= self.accum_addr_out
385
- self.act_func.next <<= self.act_control_regs[-1][0]
386
- self.act_start.next <<= self.act_control_regs[-1][1]
 
 
387
 
388
  def _connect_components(self):
389
  """Internal component connections"""
@@ -391,11 +656,16 @@ class CompiledAccelerator:
391
  self._create_pipeline_registers()
392
 
393
  # Connect buffer to external inputs
394
  self.systolic_array.connect_inputs(
395
  data_inputs=self.data_ins,
396
  enable_input=self.data_enable,
397
- weight_inputs=self.weights_in,
398
- weight_enable=self.weight_enable,
399
  )
400
 
401
  # Connect accumulator to systolic array
@@ -421,8 +691,8 @@ class CompiledAccelerator:
421
  self,
422
  data_enable: WireVector | None = None,
423
  data_inputs: list[WireVector] | None = None,
424
- weight_enable: WireVector | None = None,
425
- weights_in: list[WireVector] | None = None,
426
  accum_addr: WireVector | None = None,
427
  accum_mode: WireVector | None = None,
428
  act_start: WireVector | None = None,
@@ -437,8 +707,9 @@ class CompiledAccelerator:
437
  Args:
438
  data_enable: 1-bit signal that enables data flow into the systolic array
439
  data_inputs: List of input data wires for the systolic array. Must match array_size
440
- weight_enable: 1-bit signal enable writing new weights to systolic array registers
441
- weights_in: List of input weight wires for the systolic array. Must match array_size
 
442
  accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
443
  accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
444
  act_start: 1-bit signal to enable passing data through the activation unit
@@ -464,21 +735,16 @@ class CompiledAccelerator:
464
  )
465
  self.data_ins[i] <<= wire
466
 
467
- if weight_enable is not None:
468
- assert len(weight_enable) == 1, "Weight start signal must be 1 bit wide"
469
- self.weight_enable <<= weight_enable
470
 
471
- if weights_in is not None:
472
- assert len(weights_in) == self.config.array_size, (
473
- f"Weights input list length must match array size. "
474
- f"Expected {self.config.array_size}, got {len(weights_in)}"
475
  )
476
- for i, wire in enumerate(weights_in):
477
- assert len(wire) == self.config.weight_type.bitwidth(), (
478
- f"Weight input wire width mismatch. "
479
- f"Expected {self.config.weight_type.bitwidth()}, got {len(wire)}"
480
- )
481
- self.weights_in[i] <<= wire
482
 
483
  if accum_addr is not None:
484
  assert len(accum_addr) == self.config.accum_addr_width, (
@@ -515,6 +781,34 @@ class CompiledAccelerator:
515
  assert len(valid) == 1, "Output valid signal must be a single bit wire"
516
  valid <<= self.activation.outputs_valid
517
 
518
 
519
  @dataclass
520
  class TiledAcceleratorConfig:
 
1
+ from __future__ import annotations
2
  from dataclasses import dataclass
3
+ from typing import Callable, Literal, Type, Dict
4
  import numpy as np
5
+ import hashlib
6
+ import json
7
 
8
  from pyrtl import (
9
  WireVector,
10
  Register,
11
+ Input,
12
  Output,
13
  Simulation,
14
+ CompiledSimulation,
15
  concat,
16
  )
17
 
18
+ from .adders import float_adder
19
 
20
+ from ..dtypes.bfloat16 import BF16
21
+
22
+ from .multipliers import *
23
+ from .adders import *
24
+ from .lmul import *
25
  from .buffer import BufferMemory, WeightFIFO
26
  from .systolic import SystolicArrayDiP
27
  from .accumulators import Accumulator, TiledAccumulatorMemoryBank
28
+ from .activations import ReluState, ReluUnit
29
  from ..dtypes import BaseFloat
30
 
31
+ from dataclasses import dataclass
32
 
 
 
 
33
 
34
+ @dataclass # (frozen=True)
35
+ class CompiledAcceleratorConfig:
36
+ """Configuration for a compiled accelerator."""
 
37
 
38
  array_size: int
39
+ activation_type: Type[BaseFloat]
 
 
 
 
 
 
 
40
  weight_type: Type[BaseFloat]
41
+ multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
42
+ accum_addr_width: int = 12 # 4096 accumulator slots
43
+ pipeline_pe: bool = False
44
+
45
+ def __post_init__(self):
46
+ """Validate configuration after initialization."""
47
+ # Ensure activation dtype has bitwidth >= weight dtype
48
+ if self.activation_type.bitwidth() < self.weight_type.bitwidth():
49
+ raise ValueError(
50
+ f"Activation dtype bitwidth ({self.activation_type.bitwidth()}) must be greater than or equal to "
51
+ f"weight dtype bitwidth ({self.weight_type.bitwidth()})"
52
+ )
53
 
54
+ @property
55
+ def name(self):
56
+ dtype_name = lambda d: d.bitwidth() if d != BF16 else "b16"
57
+ lmul = "-lmul" if "lmul" in self.multiplier.__name__.lower() else ""
58
+ mem = f"-m{self.accum_addr_width}" if self.accum_addr_width != 12 else ""
59
+ return (
60
+ f"w{dtype_name(self.weight_type)}"
61
+ f"a{dtype_name(self.activation_type)}"
62
+ f"-{self.array_size}x{self.array_size}"
63
+ f"{lmul}"
64
+ f"{'-p' if self.pipeline_pe else ''}"
65
+ f"{mem}"
66
+ )
67
 
68
+ def __repr__(self) -> str:
69
+ return (
70
+ "CompiledAcceleratorConfig(\n"
71
+ f"\tarray_size={self.array_size}\n"
72
+ f"\tactivation_type={self.activation_type.__name__}\n"
73
+ f"\tweight_type={self.weight_type.__name__}\n"
74
+ f"\tmultiplier={self.multiplier.__name__}\n"
75
+ f"\taccum_addr_width={self.accum_addr_width}\n"
76
+ f"\tpipeline={self.pipeline_pe}\n"
77
+ # f'\tname="{self.name}"\n'
78
+ ")"
79
+ )
80
 
81
+ def __hash__(self) -> int:
82
+ """Generate a consistent hash value for this configuration.
83
 
84
+ Returns:
85
+ An integer hash value.
86
+ """
87
+ # Create a dictionary of the key configuration parameters
88
+ config_dict = {
89
+ "array_size": self.array_size,
90
+ "activation_type": f"{self.activation_type.__module__}.{self.activation_type.__name__}",
91
+ "weight_type": f"{self.weight_type.__module__}.{self.weight_type.__name__}",
92
+ "multiplier": self.multiplier.__name__,
93
+ "accum_addr_width": self.accum_addr_width,
94
+ "pipeline": self.pipeline_pe,
95
+ }
96
 
97
+ # Generate a hash from the sorted JSON representation
98
+ hash_str = hashlib.sha256(
99
+ json.dumps(config_dict, sort_keys=True).encode()
100
+ ).hexdigest()
101
 
102
+ # Convert the first 16 characters of the hex string to an integer
103
+ return int(hash_str[:16], 16)
104
 
105
  @property
106
+ def id(self):
107
+ """Get a unique hexadecimal identifier for this configuration."""
108
+ return hex(self.__hash__())[2:]
109
 
110
 
111
+ class CompiledAccelerator:
112
+ def __init__(self, config: CompiledAcceleratorConfig):
113
  self.config = config
114
 
115
  # Instantiate hardware components
116
  self.systolic_array = SystolicArrayDiP(
117
  size=config.array_size,
118
+ data_type=config.activation_type,
119
  weight_type=config.weight_type,
120
+ accum_type=config.activation_type,
121
+ multiplier=config.multiplier,
122
+ adder=float_adder,
123
+ pipeline=config.pipeline_pe,
124
  )
125
  self.accumulator = Accumulator(
126
+ addr_width=config.accum_addr_width,
127
  array_size=config.array_size,
128
+ data_type=config.activation_type,
129
+ adder=float_adder,
130
  )
131
  self.activation = ReluUnit(
132
  size=config.array_size,
133
+ dtype=config.activation_type,
134
  )
135
  self.outputs = [
136
+ WireVector(config.activation_type.bitwidth())
137
+ for _ in range(config.array_size)
138
  ]
139
 
140
  # Connect components
 
144
  """Create unnamed WireVectors for control signals"""
145
  self.data_enable = WireVector(1)
146
  self.data_ins = [
147
+ WireVector(self.config.activation_type.bitwidth())
148
  for _ in range(self.config.array_size)
149
  ]
150
 
151
+ self.weight_enable = WireVector(1)
152
+ self.weights_in = [
153
+ WireVector(self.config.weight_type.bitwidth())
154
+ for _ in range(self.config.array_size)
155
+ ]
156
 
157
  self.accum_addr_in = WireVector(self.config.accum_addr_width)
158
  self.accum_mode_in = WireVector(1)
 
161
  self.act_func_in = WireVector(1) # Apply activation function or passthrough
162
 
163
  def _create_pipeline_registers(self):
164
+ num_registers = self.config.array_size + 1 + int(self.config.pipeline_pe)
165
 
166
  self.accum_addr_regs = [
167
  Register(self.config.accum_addr_width) for _ in range(num_registers)
 
197
  self._create_pipeline_registers()
198
 
199
  # Connect buffer to external inputs
200
  self.systolic_array.connect_inputs(
201
  data_inputs=self.data_ins,
202
  enable_input=self.data_enable,
203
+ weight_inputs=self.weights_in,
204
+ weight_enable=self.weight_enable,
205
  )
206
 
207
  # Connect accumulator to systolic array
 
227
  self,
228
  data_enable: WireVector | None = None,
229
  data_inputs: list[WireVector] | None = None,
230
+ weight_enable: WireVector | None = None,
231
+ weights_in: list[WireVector] | None = None,
232
  accum_addr: WireVector | None = None,
233
  accum_mode: WireVector | None = None,
234
  act_start: WireVector | None = None,
 
243
  Args:
244
  data_enable: 1-bit signal that enables data flow into the systolic array
245
  data_inputs: List of input data wires for the systolic array. Must match array_size
246
+ weight_enable: 1-bit signal that enables writing new weights to the systolic array registers
247
+ weights_in: List of input weight wires for the systolic array. Must match array_size
 
248
  accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
249
  accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
250
  act_start: 1-bit signal to enable passing data through the activation unit
 
264
  f"Expected {self.config.array_size}, got {len(data_inputs)}"
265
  )
266
  for i, wire in enumerate(data_inputs):
267
+ assert len(wire) == self.config.activation_type.bitwidth(), (
268
  f"Data input width mismatch. "
269
+ f"Expected {self.config.activation_type.bitwidth()}, got {len(wire)}"
270
  )
271
  self.data_ins[i] <<= wire
272
 
273
+ if weight_enable is not None:
274
+ assert len(weight_enable) == 1, "Weight enable signal must be 1 bit wide"
275
+ self.weight_enable <<= weight_enable
276
 
277
+ if weights_in is not None:
278
+ assert len(weights_in) == self.config.array_size, (
279
+ f"Weights input list length must match array size. "
280
+ f"Expected {self.config.array_size}, got {len(weights_in)}"
281
  )
282
+ for i, wire in enumerate(weights_in):
283
+ assert len(wire) == self.config.weight_type.bitwidth(), (
284
+ f"Weight input wire width mismatch. "
285
+ f"Expected {self.config.weight_type.bitwidth()}, got {len(wire)}"
286
+ )
287
+ self.weights_in[i] <<= wire
288
 
289
  if accum_addr is not None:
290
  assert len(accum_addr) == self.config.accum_addr_width, (
 
321
  assert len(valid) == 1, "Output valid signal must be a single bit wire"
322
  valid <<= self.activation.outputs_valid
323
 
324
+ def inspect_accumulator_state(self, sim: CompiledSimulation) -> np.ndarray:
325
+ """Return accumulator memory as an array.
326
 
327
  Args:
328
  sim: PyRTL simulation instance
 
335
  tiles = []
336
  for addr in range(2**self.config.accum_addr_width):
337
  row = [
338
+ float(
339
+ self.config.activation_type(
340
+ binint=sim.inspect_mem(bank).get(addr, 0)
341
+ )
342
+ )
343
  for bank in self.accumulator.memory_banks
344
  ]
345
  tiles.append(row)
346
  return np.array(tiles)
347
 
348
 
349
+ @dataclass
350
+ class AcceleratorAnalysisConfig:
351
+ """Configuration for an accelerator to be generated for analysis."""
352
+
353
+ array_size: int
354
+ """
355
+ The size of the systolic array (N x N).
356
+ Determines the number of processing elements in the accelerator.
357
+ """
358
+
359
+ weight_type: Type[BaseFloat]
360
+ """
361
+ The floating-point data type for weights.
362
+ Must be a subclass of BaseFloat (e.g., Float8, BF16, Float32).
363
+ """
364
+
365
+ activation_type: Type[BaseFloat]
366
+ """
367
+ The floating-point data type for activations/inputs.
368
+ Must be a subclass of BaseFloat (e.g., Float8, BF16, Float32).
369
+ """
370
+
371
+ lmul: bool
372
+ """
373
+ Whether to use L-mul for multiplication operations.
374
+ If True, uses linear-time multipliers; if False, uses standard IEEE multipliers.
375
+ """
376
+
377
+ pipeline_level: Literal["low", "high"] | None
378
+ """
379
+ The level of pipelining in the accelerator:
380
+ - None: No pipelining (fully combinational design)
381
+ - 'low': Basic pipelining between multiplier and adder in each PE
382
+ - 'high': Full pipelining with pipelined arithmetic units
383
+ """
384
+
385
+ use_fast_internals: bool
386
+ """
387
+ Whether to use faster basic arithmetic implementations with more complex low-level RTL.
388
+ - True: uses optimized arithmetic units from PyRTL's rtllib
389
+ - False: prioritize simplicity over speed
390
+
391
+ WARNING: Setting to True could potentially make final synthesis on the Verilog output worse as the synthesis tools will not be able to infer optimal circuits from the complex low-level RTL.
392
+ """
393
+
394
+ accum_addr_width: int = 12
395
+ """
396
+ The bit width of the accumulator address.
397
+ Determines the size of the accumulator memory (2^width entries).
398
+ Default is 12 bits (4096 entries).
399
+ """
400
+
401
+ def __post_init__(self):
402
+ # Ensure activation dtype has bitwidth >= weight dtype
403
+ if self.activation_type.bitwidth() < self.weight_type.bitwidth():
404
+ raise ValueError(
405
+ f"Activation dtype bitwidth ({self.activation_type.bitwidth()}) must be greater than or equal to "
406
+ f"weight dtype bitwidth ({self.weight_type.bitwidth()})"
407
+ )
408
+
409
+ # Determine if we should use pipelined arithmetic functions
410
+ use_pipelined_funcs = self.pipeline_level == "high"
411
+
412
+ # Set pipeline_pe flag for PE configuration
413
+ # True if any pipeline level is specified (low or high)
414
+ self.pipeline_pe = self.pipeline_level is not None
415
+
416
+ # Multiplier function selection using dictionary mapping
417
+ multiplier_map = {
418
+ # (lmul, use_pipelined_funcs, fast_internals) -> function
419
+ (True, True, True): lmul_pipelined_fast,
420
+ (True, True, False): lmul_pipelined,
421
+ (True, False, True): lmul_fast,
422
+ (True, False, False): lmul_simple,
423
+ (False, True, True): float_multiplier_pipelined_fast_unstable,
424
+ (False, True, False): float_multiplier_pipelined,
425
+ (False, False, True): float_multiplier_fast_unstable,
426
+ (False, False, False): float_multiplier,
427
+ }
428
+
429
+ # Adder function selection using dictionary mapping
430
+ adder_map = {
431
+ # (use_pipelined_funcs, fast_internals) -> function
432
+ (True, True): float_adder_pipelined_fast_unstable,
433
+ (True, False): float_adder_pipelined,
434
+ (False, True): float_adder_fast_unstable,
435
+ (False, False): float_adder,
436
+ }
437
+
438
+ # Select functions using the maps
439
+ self.multiplier_func = multiplier_map[
440
+ (self.lmul, use_pipelined_funcs, self.use_fast_internals)
441
+ ]
442
+ self.adder_func = adder_map[(use_pipelined_funcs, self.use_fast_internals)]
443
+
444
+ @property
445
+ def name(self):
446
+ dtype_name = lambda d: d.bitwidth() if d != BF16 else "b16"
447
+ mul = "-lmul" if self.lmul else "-ieee"
448
+ pipe_name_map = {"low": "-pipePE", "high": "-pipeALL"}
449
+ fast = "-fast" if self.use_fast_internals else ""
450
+ mem = f"-m{self.accum_addr_width}" if self.accum_addr_width != 12 else ""
451
+ return (
452
+ f"w{dtype_name(self.weight_type)}"
453
+ f"a{dtype_name(self.activation_type)}"
454
+ f"-{self.array_size}x{self.array_size}"
455
+ + mem
456
+ + mul
457
+ + fast
458
+ + pipe_name_map.get(self.pipeline_level, "") # type: ignore
459
+ )
460
+
461
+
462
+ class AcceleratorTopLevel(CompiledAccelerator):
463
+ def __init__(self, config: AcceleratorAnalysisConfig):
464
+ self.config = config
465
+
466
+ # Instantiate hardware components
467
+ self.systolic_array = SystolicArrayDiP(
468
+ size=config.array_size,
469
+ data_type=config.activation_type,
470
+ weight_type=config.weight_type,
471
+ accum_type=config.activation_type,
472
+ multiplier=config.multiplier_func,
473
+ adder=config.adder_func,
474
+ pipeline=config.pipeline_pe,
475
+ )
476
+ self.accumulator = Accumulator(
477
+ addr_width=config.accum_addr_width,
478
+ array_size=config.array_size,
479
+ data_type=config.activation_type,
480
+ adder=config.adder_func,
481
+ )
482
+ self.activation = ReluUnit(
483
+ size=config.array_size,
484
+ dtype=config.activation_type,
485
+ )
486
+ self.outputs = [
487
+ Output(config.activation_type.bitwidth(), f"out_{i}")
488
+ for i in range(config.array_size)
489
+ ]
490
+
491
+ # Connect everything together and create io ports
492
+ self._connect_components()
493
+ self.valid_out = Output(1, "valid_out")
494
+ self.valid_out <<= self.activation.outputs_valid
495
+
496
+ def _create_control_wires(self):
497
+ """Create named Input wires for control signals"""
498
+ self.data_enable = Input(1, "data_enable")
499
+ self.data_ins = [
500
+ Input(self.config.activation_type.bitwidth(), f"data_in_{i}")
501
+ for i in range(self.config.array_size)
502
+ ]
503
+ self.weight_enable = Input(1, "weight_enable")
504
+ self.weights_in = [
505
+ Input(self.config.weight_type.bitwidth(), f"weight_in_{i}")
506
+ for i in range(self.config.array_size)
507
+ ]
508
+ self.accum_addr_in = Input(self.config.accum_addr_width, "accum_addr_in")
509
+ self.accum_mode_in = Input(1, "accum_mode_in")
510
+ self.act_start_in = Input(1, "act_start_in")
511
+ self.act_func_in = Input(1, "act_func_in")
512
+
513
+
514
+ @dataclass(unsafe_hash=True)
515
+ class AcceleratorConfig:
516
+ """Configuration class for a systolic array accelerator.
517
+
518
+ This class defines the parameters and specifications for a systolic array
519
+ accelerator including array dimensions, data types, arithmetic operations,
520
+ and memory configuration.
521
+ """
522
+
523
+ array_size: int
524
+ """Dimension of systolic array (always square)"""
525
+
526
+ num_weight_tiles: int
527
+ """Number of weight tiles in the FIFO. Each tile is equal to the size of the systolic array"""
528
+
529
+ data_type: Type[BaseFloat]
530
+ """Floating point format of input data to systolic array"""
531
+
532
+ weight_type: Type[BaseFloat]
533
+ """Floating point format of weight inputs"""
534
+
535
+ accum_type: Type[BaseFloat]
536
+ """Floating point format to accumulate values in"""
537
+
538
+ pe_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
539
+ """Function to generate adder hardware for the processing elements"""
540
+
541
+ accum_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
542
+ """Function to generate adder hardware for the accumulator buffer"""
543
+
544
+ pe_multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
545
+ """Function to generate multiplier hardware for the processing elements"""
546
+
547
+ pipeline: bool
548
+ """Whether to add a pipeline stage in processing elements between multiplier and adder"""
549
+
550
+ accum_addr_width: int
551
+ """Address width for accumulator memory. Determines number of individually addressable locations"""
552
+
553
+ @property
554
+ def weight_tile_addr_width(self):
555
+ """Get the width of the weight tile address bus in bits"""
556
+ return (self.num_weight_tiles - 1).bit_length()
557
+
558
+
559
+ class Accelerator:
560
  def __init__(self, config: AcceleratorConfig):
561
  self.config = config
562
 
563
  # Instantiate hardware components
564
+ self.fifo = WeightFIFO(
565
+ array_size=config.array_size,
566
+ num_tiles=config.num_weight_tiles,
567
+ dtype=config.weight_type,
568
+ )
569
  self.systolic_array = SystolicArrayDiP(
570
  size=config.array_size,
571
  data_type=config.data_type,
 
600
  for _ in range(self.config.array_size)
601
  ]
602
 
603
+ self.weight_start_in = WireVector(1)
604
+ self.weight_tile_addr_in = WireVector(self.fifo.tile_addr_width)
 
 
 
605
 
606
  self.accum_addr_in = WireVector(self.config.accum_addr_width)
607
  self.accum_mode_in = WireVector(1)
 
622
  self.accum_mode_out = WireVector(1)
623
  self.accum_mode_out <<= self.accum_mode_regs[-1]
624
 
625
+ self.act_start_regs = [Register(1) for _ in range(num_registers)]
626
+ self.act_enable_regs = [Register(1) for _ in range(num_registers)]
627
+ self.act_start_regs[0].next <<= self.act_start_in
628
+ self.act_enable_regs[0].next <<= self.act_func_in
629
+
630
+ # self.act_control_regs = [Register(2) for _ in range(num_registers)]
631
+ # self.act_control_regs[0].next <<= concat(self.act_start_in, self.act_func_in)
632
 
633
  self.accum_addr_regs[0].next <<= self.accum_addr_in
634
  self.accum_mode_regs[0].next <<= self.accum_mode_in
635
  for i in range(1, len(self.accum_addr_regs)):
636
  self.accum_addr_regs[i].next <<= self.accum_addr_regs[i - 1]
637
  self.accum_mode_regs[i].next <<= self.accum_mode_regs[i - 1]
638
+ # self.act_control_regs[i].next <<= self.act_control_regs[i - 1]
639
+ if i < len(self.act_start_regs):
640
+ self.act_enable_regs[i].next <<= self.act_enable_regs[i - 1]
641
+ self.act_start_regs[i].next <<= self.act_start_regs[i - 1]
642
 
643
  self.act_addr = Register(self.config.accum_addr_width)
644
  self.act_func = Register(1)
645
  self.act_start = Register(1)
646
 
647
  self.act_addr.next <<= self.accum_addr_out
648
+ # self.act_func.next <<= self.act_control_regs[-1][0]
649
+ # self.act_start.next <<= self.act_control_regs[-1][1]
650
+ self.act_func.next <<= self.act_enable_regs[-1]
651
+ self.act_start.next <<= self.act_start_regs[-1]
652
 
653
  def _connect_components(self):
654
  """Internal component connections"""
 
656
  self._create_pipeline_registers()
657
 
658
  # Connect buffer to external inputs
659
+ self.fifo.connect_inputs(
660
+ start=self.weight_start_in,
661
+ tile_addr=self.weight_tile_addr_in,
662
+ )
663
+
664
  self.systolic_array.connect_inputs(
665
  data_inputs=self.data_ins,
666
  enable_input=self.data_enable,
667
+ weight_inputs=self.fifo.outputs.weights,
668
+ weight_enable=self.fifo.outputs.active,
669
  )
670
 
671
  # Connect accumulator to systolic array
 
691
  self,
692
  data_enable: WireVector | None = None,
693
  data_inputs: list[WireVector] | None = None,
694
+ weight_start: WireVector | None = None,
695
+ weight_tile_addr: WireVector | None = None,
696
  accum_addr: WireVector | None = None,
697
  accum_mode: WireVector | None = None,
698
  act_start: WireVector | None = None,
 
707
  Args:
708
  data_enable: 1-bit signal that enables data flow into the systolic array
709
  data_inputs: List of input data wires for the systolic array. Must match array_size
710
+ weight_start: 1-bit signal that triggers loading of a new weight tile when pulsed high
711
+ weight_tile_addr: Address selecting which weight tile to load from the FIFO.
712
+ Width must match the FIFO's tile address width
713
  accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
714
  accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
715
  act_start: 1-bit signal to enable passing data through the activation unit
 
735
  )
736
  self.data_ins[i] <<= wire
737
 
738
+ if weight_start is not None:
739
+ assert len(weight_start) == 1, "Weight start signal must be 1 bit wide"
740
+ self.weight_start_in <<= weight_start
741
 
742
+ if weight_tile_addr is not None:
743
+ assert len(weight_tile_addr) == self.fifo.tile_addr_width, (
744
+ f"Weight tile address width mismatch. "
745
+ f"Expected {self.fifo.tile_addr_width}, got {len(weight_tile_addr)}"
746
  )
747
+ self.weight_tile_addr_in <<= weight_tile_addr
748
 
749
  if accum_addr is not None:
750
  assert len(accum_addr) == self.config.accum_addr_width, (
 
781
  assert len(valid) == 1, "Output valid signal must be a single bit wire"
782
  valid <<= self.activation.outputs_valid
783
 
784
+ def inspect_systolic_array_state(self, sim: Simulation):
785
+ """Return current PE array state"""
786
+ return self.systolic_array.get_state(sim)
787
+
788
+ def inspect_accumulator_state(self, sim: Simulation) -> np.ndarray:
789
+ """Return all accumulator tiles as 3D array.
790
+
791
+ Args:
792
+ sim: PyRTL simulation instance
793
+
794
+ Returns:
795
+ 2D numpy array of shape (2**accum_addr_width, array_size) containing
796
+ all accumulator tile data converted to floating point values.
797
+ Each tile contains array_size rows with array_size columns.
798
+ """
799
+ tiles = []
800
+ for addr in range(2**self.config.accum_addr_width):
801
+ row = [
802
+ float(self.config.accum_type(binint=sim.inspect_mem(bank).get(addr, 0)))
803
+ for bank in self.accumulator.memory_banks
804
+ ]
805
+ tiles.append(row)
806
+ return np.array(tiles)
807
+
808
+ def inspect_activation_state(self, sim: Simulation) -> ReluState:
809
+ """Return current activation unit state"""
810
+ return self.activation.inspect_state(sim)
811
+
812
 
813
  @dataclass
814
  class TiledAcceleratorConfig:
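
The new CompiledAcceleratorConfig/CompiledAccelerator pair in accelerator.py drops the on-chip weight FIFO and instead streams weights straight into the systolic array, with one activation dtype shared by the data path and the accumulator. A minimal usage sketch, not a definitive recipe: the field values and the expected name string below are illustrative, and it assumes BF16, Float8, and lmul_fast are importable as shown in the diff above.

    import pyrtl
    from hardware_accelerators.rtllib import (
        CompiledAcceleratorConfig,
        CompiledAccelerator,
        lmul_fast,
    )
    from hardware_accelerators.dtypes import Float8
    from hardware_accelerators.dtypes.bfloat16 import BF16

    pyrtl.reset_working_block()

    config = CompiledAcceleratorConfig(
        array_size=4,              # 4x4 systolic array
        activation_type=BF16,      # data path and accumulator dtype
        weight_type=Float8,        # weight dtype (bitwidth must not exceed activation_type)
        multiplier=lmul_fast,      # l-mul based PE multiplier
    )
    print(config.name)   # should read something like "w8ab16-4x4-lmul"
    print(config.id)     # short hex digest derived from the config hash

    accel = CompiledAccelerator(config)   # elaborates the design into the current PyRTL block
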
hardware_accelerators/rtllib/activations.py CHANGED
@@ -1,4 +1,6 @@
1
- from typing import Sequence, Type
 
 
2
  from pyrtl import (
3
  WireVector,
4
  Input,
@@ -13,6 +15,32 @@ from pyrtl import (
13
  from ..dtypes.base import BaseFloat
14
 
15
 
  class ReluUnit:
17
  def __init__(self, size: int, dtype: Type[BaseFloat]):
18
  self.size = size
@@ -21,18 +49,24 @@ class ReluUnit:
21
  # Control signals
22
  self.start = WireVector(1) # trigger to latch new enable value
23
  self.enable_in = WireVector(1) # input enable value to latch
24
- self.inputs_valid = WireVector(1) # indicates if inputs are valid
25
  self.enable_reg = Register(1) # stateful enable register
 
 
 
26
 
27
  # Input and output data
28
- self.data = [WireVector(dtype.bitwidth()) for _ in range(size)]
29
- self.outputs = [self.relu(x) for x in self.data]
30
  self.outputs_valid = WireVector(1)
31
- self.outputs_valid <<= self.inputs_valid
32
 
33
  def relu(self, x: WireVector):
34
  # Use enable_reg instead of enable wire
35
- pass_condition = self.inputs_valid & (
36
  ~self.enable_reg | (self.enable_reg & ~x[-1])
37
  )
38
  return select(pass_condition, x, Const(0, self.dtype.bitwidth()))
@@ -48,7 +82,7 @@ class ReluUnit:
48
  len(inputs) == self.size
49
  ), f"Activation module input size mismatch. Expected {self.size}, got {len(inputs)}"
50
  for i in range(self.size):
51
- self.data[i] <<= inputs[i]
52
  self.inputs_valid <<= valid
53
  self.enable_in <<= enable
54
  self.start <<= start
@@ -85,6 +119,34 @@ class ReluUnit:
85
  """
86
  return [float(self.dtype(binint=sim.inspect(out.name))) for out in self.outputs]
87
 
88
 
89
  # class ReluUnit:
90
  # def __init__(self, size: int, dtype: Type[BaseFloat]):
 
1
+ from dataclasses import dataclass
2
+ import numpy as np
3
+ from typing import TYPE_CHECKING, Sequence, Type
4
  from pyrtl import (
5
  WireVector,
6
  Input,
 
15
  from ..dtypes.base import BaseFloat
16
 
17
 
18
+ @dataclass
19
+ class ReluState:
20
+ start: int
21
+ enable_in: int
22
+ enable_reg: int
23
+ inputs_valid: int
24
+ inputs: np.ndarray
25
+ registers: np.ndarray
26
+ outputs_valid: int
27
+ outputs: np.ndarray
28
+
29
+ def __repr__(self) -> str:
30
+ """Pretty print the ReLU state"""
31
+ status = "enabled" if self.enable_reg else "disabled"
32
+ valid_str = "(valid)" if self.outputs_valid else "(invalid)"
33
+
34
+ return (
35
+ f"ReLU {status} {valid_str}\n"
36
+ f" Control: start={self.start}, enable_in={self.enable_in}, "
37
+ f"enable_reg={self.enable_reg}, inputs_valid={self.inputs_valid}\n"
38
+ f" Inputs: {np.array2string(self.inputs, precision=4, suppress_small=True)}\n"
39
+ f" Registers: {np.array2string(self.registers, precision=4, suppress_small=True)}\n"
40
+ f" Outputs: {np.array2string(self.outputs, precision=4, suppress_small=True)}"
41
+ )
42
+
43
+
44
  class ReluUnit:
45
  def __init__(self, size: int, dtype: Type[BaseFloat]):
46
  self.size = size
 
49
  # Control signals
50
  self.start = WireVector(1) # trigger to latch new enable value
51
  self.enable_in = WireVector(1) # input enable value to latch
 
52
  self.enable_reg = Register(1) # stateful enable register
53
+ self.inputs_valid = WireVector(1) # indicates if inputs are valid
54
+ self.valid_reg = Register(1) # stateful valid register
55
+ self.valid_reg.next <<= self.inputs_valid
56
 
57
  # Input and output data
58
+ self.data_in = [WireVector(dtype.bitwidth()) for _ in range(size)]
59
+ self.data_regs = [Register(dtype.bitwidth()) for _ in range(size)]
60
+ for data, reg in zip(self.data_in, self.data_regs):
61
+ reg.next <<= data
62
+
63
+ self.outputs = [self.relu(x) for x in self.data_regs]
64
  self.outputs_valid = WireVector(1)
65
+ self.outputs_valid <<= self.valid_reg
66
 
67
  def relu(self, x: WireVector):
68
  # Use enable_reg instead of enable wire
69
+ pass_condition = self.valid_reg & (
70
  ~self.enable_reg | (self.enable_reg & ~x[-1])
71
  )
72
  return select(pass_condition, x, Const(0, self.dtype.bitwidth()))
 
82
  len(inputs) == self.size
83
  ), f"Activation module input size mismatch. Expected {self.size}, got {len(inputs)}"
84
  for i in range(self.size):
85
+ self.data_in[i] <<= inputs[i]
86
  self.inputs_valid <<= valid
87
  self.enable_in <<= enable
88
  self.start <<= start
 
119
  """
120
  return [float(self.dtype(binint=sim.inspect(out.name))) for out in self.outputs]
121
 
122
+ def inspect_state(self, sim: Simulation) -> ReluState:
123
+ """Inspect current state of the ReLU unit."""
124
+ return ReluState(
125
+ start=sim.inspect(self.start.name),
126
+ enable_in=sim.inspect(self.enable_in.name),
127
+ enable_reg=sim.inspect(self.enable_reg.name),
128
+ inputs_valid=sim.inspect(self.inputs_valid.name),
129
+ inputs=np.array(
130
+ [
131
+ float(self.dtype(binint=sim.inspect(inp.name)))
132
+ for inp in self.data_in
133
+ ]
134
+ ),
135
+ registers=np.array(
136
+ [
137
+ float(self.dtype(binint=sim.inspect(reg.name)))
138
+ for reg in self.data_regs
139
+ ]
140
+ ),
141
+ outputs_valid=sim.inspect(self.outputs_valid.name),
142
+ outputs=np.array(
143
+ [
144
+ float(self.dtype(binint=sim.inspect(out.name)))
145
+ for out in self.outputs
146
+ ]
147
+ ),
148
+ )
149
+
150
 
151
  # class ReluUnit:
152
  # def __init__(self, size: int, dtype: Type[BaseFloat]):
hardware_accelerators/rtllib/adders.py CHANGED
@@ -17,6 +17,7 @@ def float_adder(
17
  float_a: WireVector,
18
  float_b: WireVector,
19
  dtype: Type[BaseFloat],
 
20
  ) -> WireVector:
21
 
22
  e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
@@ -26,7 +27,7 @@ def float_adder(
26
  )
27
 
28
  sign_xor, exp_larger, signed_shift, mant_smaller, mant_larger = adder_stage_2(
29
- sign_a, sign_b, exp_a, exp_b, mantissa_a, mantissa_b, e_bits, m_bits
30
  )
31
 
32
  abs_shift = WireVector(e_bits) # , "abs_shift")
@@ -37,7 +38,7 @@ def float_adder(
37
  )
38
 
39
  mantissa_sum, is_neg, lzc = adder_stage_4(
40
- aligned_mant_msb, mant_larger, sign_xor, m_bits
41
  )
42
 
43
  final_sign, final_exp, norm_mantissa = adder_stage_5(
@@ -56,13 +57,52 @@ def float_adder(
56
  )
57
 
58
  float_result = WireVector(dtype.bitwidth()) # , "float_result")
59
- float_result <<= pyrtl.concat(final_sign, final_exp, norm_mantissa)
  return float_result
61
 
62
 
 
63
  ### ===================================================================
64
  ### Simple Pipeline Design
65
  ### ===================================================================
 
67
 
68
  class FloatAdderPipelined(SimplePipeline):
@@ -72,6 +112,7 @@ class FloatAdderPipelined(SimplePipeline):
72
  float_b: WireVector,
73
  w_en: WireVector,
74
  dtype: Type[BaseFloat],
 
75
  ):
76
  """
77
  Initialize a pipelined BFloat16 adder with write enable control.
@@ -134,17 +175,17 @@ class FloatAdderPipelined(SimplePipeline):
134
  write enable is not 1 bit
135
  """
136
  assert (
137
- len(float_a) == len(float_b) == 16
138
  ), f"float inputs must be {dtype.bitwidth()} bits"
139
  assert len(w_en) == 1, "write enable signal must be 1 bit"
140
-
141
  self.e_bits, self.m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
142
  # Define inputs and outputs
143
  self._float_a, self._float_b = float_a, float_b
144
  self._write_enable = w_en
145
- # self._result = pyrtl.Register(self.e_bits + self.m_bits + 1, 'result')
146
  self._result_out = pyrtl.WireVector(dtype.bitwidth()) # , "_result")
147
- super(FloatAdderPipelined, self).__init__()
148
 
149
  @property
150
  def result(self):
@@ -183,6 +224,7 @@ class FloatAdderPipelined(SimplePipeline):
183
  self.mant_b,
184
  self.e_bits,
185
  self.m_bits,
 
186
  )
187
 
188
  def stage2(self):
@@ -219,7 +261,11 @@ class FloatAdderPipelined(SimplePipeline):
219
 
220
  # Perform mantissa addition and leading zero detection
221
  self.mant_sum, self.is_neg, self.lzc = adder_stage_4(
222
- self.aligned_mant_msb, self.mant_larger, self.sign_xor, self.m_bits
  )
224
 
225
  def stage4(self):
 
17
  float_a: WireVector,
18
  float_b: WireVector,
19
  dtype: Type[BaseFloat],
20
+ fast: bool = False,
21
  ) -> WireVector:
22
 
23
  e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
 
27
  )
28
 
29
  sign_xor, exp_larger, signed_shift, mant_smaller, mant_larger = adder_stage_2(
30
+ sign_a, sign_b, exp_a, exp_b, mantissa_a, mantissa_b, e_bits, m_bits, fast
31
  )
32
 
33
  abs_shift = WireVector(e_bits) # , "abs_shift")
 
38
  )
39
 
40
  mantissa_sum, is_neg, lzc = adder_stage_4(
41
+ aligned_mant_msb, mant_larger, sign_xor, m_bits, fast
42
  )
43
 
44
  final_sign, final_exp, norm_mantissa = adder_stage_5(
 
57
  )
58
 
59
  float_result = WireVector(dtype.bitwidth()) # , "float_result")
60
+
61
+ # Zero detection logic
62
+ a_is_zero = ~pyrtl.or_all_bits(float_a[:-1])
63
+ b_is_zero = ~pyrtl.or_all_bits(float_b[:-1])
64
+
65
+ with pyrtl.conditional_assignment:
66
+ with a_is_zero:
67
+ float_result |= float_b
68
+ with b_is_zero:
69
+ float_result |= float_a
70
+ with pyrtl.otherwise:
71
+ float_result |= pyrtl.concat(final_sign, final_exp, norm_mantissa)
72
+
73
  return float_result
74
 
75
 
76
+ def float_adder_fast_unstable(
77
+ float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]
78
+ ) -> WireVector:
79
+ return float_adder(float_a, float_b, dtype, fast=True)
80
+
81
+
82
  ### ===================================================================
83
  ### Simple Pipeline Design
84
  ### ===================================================================
85
+ # TODO: add zero detection logic
86
+
87
+
88
+ def float_adder_pipelined(
89
+ float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat], fast: bool = False
90
+ ) -> WireVector:
91
+ w_en = pyrtl.Input(1)
92
+ w_en.name = w_en.name.replace("tmp", "adder_w_en_in")
93
+ adder = FloatAdderPipelined(float_a, float_b, w_en, dtype, fast=fast)
94
+ return adder._result_out
95
+
96
+
97
+ def float_adder_pipelined_fast_unstable(
98
+ float_a: WireVector,
99
+ float_b: WireVector,
100
+ dtype: Type[BaseFloat],
101
+ ) -> WireVector:
102
+ w_en = pyrtl.Input(1)
103
+ w_en.name = w_en.name.replace("tmp", "adder_w_en_in")
104
+ adder = FloatAdderPipelined(float_a, float_b, w_en, dtype, fast=True)
105
+ return adder._result_out
106
 
107
 
108
  class FloatAdderPipelined(SimplePipeline):
 
112
  float_b: WireVector,
113
  w_en: WireVector,
114
  dtype: Type[BaseFloat],
115
+ fast: bool = False,
116
  ):
117
  """
118
  Initialize a pipelined BFloat16 adder with write enable control.
 
175
  write enable is not 1 bit
176
  """
177
  assert (
178
+ len(float_a) == len(float_b) == dtype.bitwidth()
179
  ), f"float inputs must be {dtype.bitwidth()} bits"
180
  assert len(w_en) == 1, "write enable signal must be 1 bit"
181
+ self._fast = fast
182
  self.e_bits, self.m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
183
  # Define inputs and outputs
184
  self._float_a, self._float_b = float_a, float_b
185
  self._write_enable = w_en
186
+ # self._result = pyrtl.Register(self.e_bits + self.m_bits + 1, "result")
187
  self._result_out = pyrtl.WireVector(dtype.bitwidth()) # , "_result")
188
+ super().__init__("float_adder")
189
 
190
  @property
191
  def result(self):
 
224
  self.mant_b,
225
  self.e_bits,
226
  self.m_bits,
227
+ self._fast,
228
  )
229
 
230
  def stage2(self):
 
261
 
262
  # Perform mantissa addition and leading zero detection
263
  self.mant_sum, self.is_neg, self.lzc = adder_stage_4(
264
+ self.aligned_mant_msb,
265
+ self.mant_larger,
266
+ self.sign_xor,
267
+ self.m_bits,
268
+ self._fast,
269
  )
270
 
271
  def stage4(self):
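
The float_adder changes add a fast flag (threaded through adder_stage_2 and adder_stage_4) plus explicit zero detection, so adding 0.0 now returns the other operand unchanged. A combinational sketch that exercises both behaviors; the hex constants are standard BF16 bit patterns and the expected values assume the adder rounds as intended.

    import pyrtl
    from hardware_accelerators.rtllib.adders import float_adder
    from hardware_accelerators.dtypes.bfloat16 import BF16

    pyrtl.reset_working_block()
    a = pyrtl.Input(BF16.bitwidth(), "a")
    b = pyrtl.Input(BF16.bitwidth(), "b")
    s = pyrtl.Output(BF16.bitwidth(), "s")
    s <<= float_adder(a, b, BF16)         # pass fast=True to use the faster rtllib internals instead

    sim = pyrtl.Simulation()
    sim.step({"a": 0x3F80, "b": 0x4000})  # 1.0 + 2.0 -> expect 0x4040 (3.0)
    print(hex(sim.inspect("s")))
    sim.step({"a": 0x0000, "b": 0x4000})  # 0.0 + 2.0 -> new zero-detection path returns b (0x4000)
    print(hex(sim.inspect("s")))
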
hardware_accelerators/rtllib/legacy.py CHANGED
@@ -1,11 +1,82 @@
 
1
  import pyrtl
 
2
  from pyrtl.rtllib.adders import carrysave_adder, kogge_stone
3
 
 
 
4
  from .utils.lmul_utils import get_combined_offset
5
 
 
6
  ###########################
7
  # Old code below
8
  ###########################
9
 
10
 
11
  # BF16 Naive Combinatorial
 
1
+ from typing import Type
2
  import pyrtl
3
+ from pyrtl import WireVector
4
  from pyrtl.rtllib.adders import carrysave_adder, kogge_stone
5
 
6
+ from ..dtypes.base import BaseFloat
7
+
8
  from .utils.lmul_utils import get_combined_offset
9
 
10
+
11
  ###########################
12
  # Old code below
13
  ###########################
14
+ def lmul_simple(
15
+ float_a: WireVector,
16
+ float_b: WireVector,
17
+ dtype: Type[BaseFloat],
18
+ ):
19
+ """Linear time complexity float multiply unit in the simplest configuration."""
20
+ e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
21
+ em_bits = e_bits + m_bits
22
+ sign_out = float_a[em_bits] ^ float_b[em_bits]
23
+
24
+ unsigned_offset = pyrtl.Const(get_combined_offset(e_bits, m_bits), em_bits)
25
+ result_sum = float_a[:em_bits] + float_b[:em_bits] - unsigned_offset
26
+
27
+ fp_out = WireVector(bitwidth=em_bits + 1)
28
+ fp_out <<= pyrtl.concat(sign_out, pyrtl.truncate(result_sum, em_bits))
29
+ return fp_out
30
+
31
+
32
+ def lmul_fast(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
33
+ e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
34
+ em_bits = e_bits + m_bits
35
+ sign_a = float_a[em_bits]
36
+ sign_b = float_b[em_bits]
37
+ exp_mantissa_a = float_a[:em_bits]
38
+ exp_mantissa_b = float_b[:em_bits]
39
+ fp_out = WireVector(em_bits + 1)
40
+
41
+ # Calculate result sign
42
+ result_sign = sign_a ^ sign_b
43
+
44
+ # Add exp_mantissa parts using kogge_stone adder (faster than ripple)
45
+ # exp_mantissa_sum = kogge_stone(exp_mantissa_a, exp_mantissa_b)
46
+
47
+ # Get the combined offset-bias constant
48
+ OFFSET_MINUS_BIAS = pyrtl.Const(
49
+ get_combined_offset(e_bits, m_bits, True), bitwidth=em_bits
50
+ )
51
+
52
+ # Add offset-bias value - this will be 8 bits including carry
53
+ # final_sum = kogge_stone(exp_mantissa_sum, OFFSET_MINUS_BIAS)
54
+
55
+ final_sum = carrysave_adder(
56
+ exp_mantissa_a, exp_mantissa_b, OFFSET_MINUS_BIAS, final_adder=kogge_stone
57
+ )
58
+
59
+ # Select result based on carry and MSB:
60
+ # carry=1: overflow -> 0x7F
61
+ # carry=0, msb=0: underflow -> 0x00
62
+ # carry=0, msb=1: normal -> result_bits
63
+
64
+ MAX_VALUE = pyrtl.Const(2**em_bits - 1, bitwidth=em_bits) # , name="max_value")
65
+
66
+ if e_bits == 4 and m_bits == 3:
67
+ MAX_VALUE = pyrtl.Const(0x7F, 7)
68
+
69
+ mantissa_result = pyrtl.mux(
70
+ final_sum[em_bits:],
71
+ pyrtl.Const(0, bitwidth=em_bits),
72
+ final_sum[:em_bits],
73
+ default=MAX_VALUE,
74
+ )
75
+
76
+ # Combine sign and result
77
+ fp_out <<= pyrtl.concat(result_sign, mantissa_result)
78
+
79
+ return fp_out
80
 
81
 
82
  # BF16 Naive Combinatorial
hardware_accelerators/rtllib/lmul.py CHANGED
@@ -1,92 +1,92 @@
1
  from typing import Type
2
 
3
  import pyrtl
4
- from pyrtl import WireVector
5
- from pyrtl.rtllib.adders import carrysave_adder, kogge_stone
6
 
7
  from ..dtypes import BaseFloat, Float8
8
- from .utils.lmul_utils import get_combined_offset
9
 
10
 
11
- def lmul_simple(
12
- float_a: WireVector,
13
- float_b: WireVector,
14
- dtype: Type[BaseFloat],
15
- ):
16
- """Linear time complexity float multiply unit in the simplest configuration."""
17
- e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
18
- em_bits = e_bits + m_bits
19
- sign_out = float_a[em_bits] ^ float_b[em_bits]
20
-
21
- unsigned_offset = pyrtl.Const(get_combined_offset(e_bits, m_bits), em_bits)
22
- result_sum = float_a[:em_bits] + float_b[:em_bits] - unsigned_offset
23
-
24
- fp_out = WireVector(bitwidth=em_bits + 1)
25
- fp_out <<= pyrtl.concat(sign_out, pyrtl.truncate(result_sum, em_bits))
26
- return fp_out
27
-
28
-
29
- def lmul_fast(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
30
  e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
31
  em_bits = e_bits + m_bits
32
  sign_a = float_a[em_bits]
33
  sign_b = float_b[em_bits]
 
 
34
  exp_mantissa_a = float_a[:em_bits]
35
  exp_mantissa_b = float_b[:em_bits]
36
- fp_out = WireVector(em_bits + 1)
37
 
38
- # Calculate result sign
39
- result_sign = sign_a ^ sign_b
 
 
40
 
41
- # Add exp_mantissa parts using kogge_stone adder (faster than ripple)
42
- # exp_mantissa_sum = kogge_stone(exp_mantissa_a, exp_mantissa_b)
43
 
44
- # Get the combined offset-bias constant
45
- OFFSET_MINUS_BIAS = pyrtl.Const(
46
- get_combined_offset(e_bits, m_bits, True), bitwidth=em_bits
47
- )
48
 
49
- # Add offset-bias value - this will be 8 bits including carry
50
- # final_sum = kogge_stone(exp_mantissa_sum, OFFSET_MINUS_BIAS)
51
 
52
- final_sum = carrysave_adder(
53
- exp_mantissa_a, exp_mantissa_b, OFFSET_MINUS_BIAS, final_adder=kogge_stone
54
- )
55
 
56
- # Select result based on carry and MSB:
57
- # carry=1: overflow -> 0x7F
58
- # carry=0, msb=0: underflow -> 0x00
59
- # carry=0, msb=1: normal -> result_bits
60
 
61
- MAX_VALUE = pyrtl.Const(2**em_bits - 1, bitwidth=em_bits) # , name="max_value")
62
 
63
- if e_bits == 4 and m_bits == 3:
64
- MAX_VALUE = pyrtl.Const(0x7F, 7)
65
 
66
- mantissa_result = pyrtl.mux(
67
- final_sum[em_bits:],
68
- pyrtl.Const(0, bitwidth=em_bits),
69
- final_sum[:em_bits],
70
- default=MAX_VALUE,
71
- )
72
 
73
- # Combine sign and result
74
- fp_out <<= pyrtl.concat(result_sign, mantissa_result)
75
 
76
- return fp_out
77
 
78
 
79
- # Float8 fast pipelined lmul
80
  class LmulPipelined:
81
  def __init__(
82
  self,
83
  float_a: WireVector,
84
  float_b: WireVector,
85
  dtype: Type[BaseFloat],
 
86
  ):
87
  self.e_bits = dtype.exponent_bits()
88
  self.m_bits = dtype.mantissa_bits()
89
  self.em_bits = dtype.bitwidth() - 1
 
90
 
91
  # Inputs and Outputs
92
  assert (
@@ -137,13 +137,16 @@ class LmulPipelined:
137
  # Calculate and register sign
138
  self.reg_sign.next <<= sign_a ^ sign_b
139
 
140
- # First addition and register result
141
- final_sum = carrysave_adder(
142
- exp_mantissa_a,
143
- exp_mantissa_b,
144
- self.OFFSET_MINUS_BIAS,
145
- final_adder=kogge_stone,
146
- )
 
 
 
147
 
148
  self.reg_final_sum.next <<= final_sum
149
 
 
1
  from typing import Type
2
 
3
  import pyrtl
4
+ from pyrtl import WireVector, conditional_assignment
5
+ from pyrtl.rtllib.adders import carrysave_adder, kogge_stone, fast_group_adder
6
 
7
  from ..dtypes import BaseFloat, Float8
8
+ from .utils.lmul_utils import get_combined_offset, lmul_offset_rtl
9
 
10
 
11
+ def lmul(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat], fast=False):
12
  e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
13
  em_bits = e_bits + m_bits
14
  sign_a = float_a[em_bits]
15
  sign_b = float_b[em_bits]
16
+ exp_a = float_a[m_bits:-1]
17
+ exp_b = float_b[m_bits:-1]
18
  exp_mantissa_a = float_a[:em_bits]
19
  exp_mantissa_b = float_b[:em_bits]
 
20
 
21
+ zero_or_subnormal = WireVector(1)
22
+ final_sum = WireVector(em_bits + 2)
23
+ carry_msb = WireVector(2)
24
+ fp_out = WireVector(dtype.bitwidth())
25
 
26
+ OFFSET_MINUS_BIAS = lmul_offset_rtl(dtype)
27
+ MAX_VALUE = pyrtl.Const(dtype.binary_max(), bitwidth=em_bits)
28
 
29
+ if fast:
30
+ final_sum <<= carrysave_adder(
31
+ exp_mantissa_a, exp_mantissa_b, OFFSET_MINUS_BIAS, final_adder=kogge_stone
32
+ )
33
+ else:
34
+ final_sum <<= exp_mantissa_a + exp_mantissa_b + OFFSET_MINUS_BIAS
35
+
36
+ carry_msb <<= final_sum[em_bits:]
37
+ zero_or_subnormal <<= ~pyrtl.or_all_bits(exp_a) | ~pyrtl.or_all_bits(exp_b)
38
+
39
+ with conditional_assignment:
40
+ with zero_or_subnormal:
41
+ fp_out |= 0
42
+ with carry_msb == 0:
43
+ fp_out |= 0
44
+ with carry_msb == 1:
45
+ fp_out |= pyrtl.concat(sign_a ^ sign_b, final_sum[:em_bits])
46
+ with pyrtl.otherwise:
47
+ fp_out |= pyrtl.concat(sign_a ^ sign_b, MAX_VALUE)
48
 
49
+ return fp_out
 
50
 
 
 
 
51
 
52
+ def lmul_simple(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
53
+ return lmul(float_a, float_b, dtype, fast=False)
 
 
54
 
 
55
 
56
+ def lmul_fast(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
57
+ return lmul(float_a, float_b, dtype, fast=True)
58
 
59
 
60
+ def lmul_pipelined(
61
+ float_a: WireVector,
62
+ float_b: WireVector,
63
+ dtype: Type[BaseFloat],
64
+ ) -> WireVector:
65
+ mult = LmulPipelined(float_a, float_b, dtype)
66
+ return mult.output_reg
67
 
68
+
69
+ def lmul_pipelined_fast(
70
+ float_a: WireVector,
71
+ float_b: WireVector,
72
+ dtype: Type[BaseFloat],
73
+ ) -> WireVector:
74
+ mult = LmulPipelined(float_a, float_b, dtype, fast=True)
75
+ return mult.output_reg
76
 
77
 
 
78
  class LmulPipelined:
79
  def __init__(
80
  self,
81
  float_a: WireVector,
82
  float_b: WireVector,
83
  dtype: Type[BaseFloat],
84
+ fast: bool = False,
85
  ):
86
  self.e_bits = dtype.exponent_bits()
87
  self.m_bits = dtype.mantissa_bits()
88
  self.em_bits = dtype.bitwidth() - 1
89
+ self._fast = fast
90
 
91
  # Inputs and Outputs
92
  assert (
 
137
  # Calculate and register sign
138
  self.reg_sign.next <<= sign_a ^ sign_b
139
 
140
+ # Add the floating point numbers with special lmul offset
141
+ if self._fast:
142
+ final_sum = carrysave_adder(
143
+ exp_mantissa_a,
144
+ exp_mantissa_b,
145
+ self.OFFSET_MINUS_BIAS,
146
+ final_adder=kogge_stone,
147
+ )
148
+ else:
149
+ final_sum = exp_mantissa_a + exp_mantissa_b + self.OFFSET_MINUS_BIAS
150
 
151
  self.reg_final_sum.next <<= final_sum
152
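
lmul.py now exposes a single lmul builder with a fast flag (plain ripple add versus carry-save plus Kogge-Stone), thin lmul_simple / lmul_fast / lmul_pipelined wrappers, and a flush-to-zero path for zero or subnormal operands. A small sketch comparing the two flavors on one BF16 product; the bit patterns are standard BF16 encodings and no exact output is asserted, since l-mul approximates the mantissa product with a single add.

    import pyrtl
    from hardware_accelerators.rtllib.lmul import lmul_simple, lmul_fast
    from hardware_accelerators.dtypes.bfloat16 import BF16

    pyrtl.reset_working_block()
    a = pyrtl.Input(BF16.bitwidth(), "a")
    b = pyrtl.Input(BF16.bitwidth(), "b")
    p_simple = pyrtl.Output(BF16.bitwidth(), "p_simple")
    p_fast = pyrtl.Output(BF16.bitwidth(), "p_fast")
    p_simple <<= lmul_simple(a, b, BF16)   # plain ripple-carry sum
    p_fast <<= lmul_fast(a, b, BF16)       # carry-save adder with Kogge-Stone final stage

    sim = pyrtl.Simulation()
    sim.step({"a": 0x3FC0, "b": 0x4000})   # 1.5 * 2.0; exact product is 3.0 (0x4040)
    print(hex(sim.inspect("p_simple")), hex(sim.inspect("p_fast")))  # both should land near 0x4040
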