Justin Chou committed
Commit 679abc4
Parent(s): e9c2b75

yeah

This view is limited to 50 files because it contains too many changes. See raw diff.
- Dockerfile +14 -4
- demo/app.py +0 -312
- demo/readme.md +0 -9
- demo/requirements.txt +0 -17
- hardware_accelerators/__init__.py +16 -1
- hardware_accelerators/analysis/__init__.py +0 -0
- hardware_accelerators/analysis/config.py +25 -0
- hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/config.mk +10 -0
- hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/constraint.sdc +1 -0
- hardware_accelerators/analysis/generate.py +958 -0
- hardware_accelerators/analysis/hardware_stats.py +458 -0
- hardware_accelerators/analysis/mnist_eval.py +274 -0
- hardware_accelerators/analysis/simple_circuits.py +258 -0
- hardware_accelerators/analysis/verilog_export.py +86 -0
- hardware_accelerators/analysis/verilog_output/pipelined_adder_BF16.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_adder_Float16.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_adder_Float32.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_adder_Float8.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_multiplier_BF16.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float16.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float32.v +37 -0
- hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float8.v +37 -0
- hardware_accelerators/analysis/verilog_output/simple_adder_BF16.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_adder_Float16.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_adder_Float32.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_adder_Float8.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_multiplier_BF16.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_multiplier_Float16.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_multiplier_Float32.v +21 -0
- hardware_accelerators/analysis/verilog_output/simple_multiplier_Float8.v +21 -0
- hardware_accelerators/app.py +388 -0
- hardware_accelerators/compile.py +167 -0
- hardware_accelerators/dtypes/__init__.py +3 -1
- hardware_accelerators/dtypes/base.py +12 -3
- hardware_accelerators/dtypes/bfloat16.py +4 -0
- hardware_accelerators/dtypes/float16.py +167 -0
- hardware_accelerators/dtypes/float32.py +174 -0
- hardware_accelerators/dtypes/float8.py +4 -0
- hardware_accelerators/nn/lmul.py +135 -0
- hardware_accelerators/nn/precision.py +264 -0
- hardware_accelerators/nn/precision_eval.py +280 -0
- hardware_accelerators/nn/run_precision_comparison.py +78 -0
- hardware_accelerators/nn/train.py +0 -2
- hardware_accelerators/nn/util.py +3 -1
- hardware_accelerators/rtllib/__init__.py +10 -2
- hardware_accelerators/rtllib/accelerator.py +407 -113
- hardware_accelerators/rtllib/activations.py +69 -7
- hardware_accelerators/rtllib/adders.py +54 -8
- hardware_accelerators/rtllib/legacy.py +71 -0
- hardware_accelerators/rtllib/lmul.py +63 -60
Dockerfile
CHANGED
@@ -1,3 +1,4 @@
+# Dockerfile for the demo
 FROM python:3.12-slim
 
 WORKDIR /code
@@ -9,20 +10,29 @@ RUN apt-get update && apt-get install -y \
     && rm -rf /var/lib/apt/lists/*
 
 # Install Python packages
-COPY
+COPY ./requirements.txt requirements.txt
 RUN pip install --no-cache-dir -r requirements.txt
 
 # Copy the model and application files
 COPY models/ /code/models/
 COPY hardware_accelerators/ /code/hardware_accelerators/
-COPY
+COPY results/component_data.csv /code/data/
 
-
+
+# Set environment variables
+ENV HWA_CACHE_DIR=/code/hardware_accelerators/cache
+ENV COMPONENT_DATA_PATH=/code/data/component_data.csv
 ENV GRADIO_SERVER_NAME=0.0.0.0
 ENV GRADIO_SERVER_PORT=7860
 
+# Copy the component data
+COPY results/component_data.csv /code/data/component_data.csv
+
+# Compile the simulations
+RUN python3 -m hardware_accelerators.compile
+
 # Expose the port Gradio runs on
 EXPOSE 7860
 
 # Command to run the Gradio app
-CMD ["
+CMD ["python3", "-m", "hardware_accelerators.app"]
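The two new ENV lines tell the compile step and the Gradio app where to find the simulation cache and the component data at runtime. A minimal sketch of how such variables are read on the Python side (the fallback defaults are invented for illustration; the real application may behave differently):

import os
from pathlib import Path

# Hypothetical lookups mirroring the ENV names set in the Dockerfile.
cache_dir = Path(os.environ.get("HWA_CACHE_DIR", ".hwa_cache"))
data_path = Path(os.environ.get("COMPONENT_DATA_PATH", "results/component_data.csv"))
cache_dir.mkdir(parents=True, exist_ok=True)  # ensure the cache directory exists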
demo/app.py
DELETED
@@ -1,312 +0,0 @@
import sys
import gradio as gr
from gradio.components.image_editor import EditorValue
import numpy as np
import torch
import pandas as pd
from PIL import Image
import torch
import torch.nn as nn
import torchvision.transforms as transforms
import tqdm
import pyrtl

sys.path.append(".")
from hardware_accelerators.nn.util import softmax
from hardware_accelerators.simulation.matrix_utils import (
    bias_trick,
    count_total_gemv_tiles,
    generate_gemv_tiles,
)
from hardware_accelerators.rtllib.adders import float_adder
from hardware_accelerators.rtllib.multipliers import float_multiplier
from hardware_accelerators.dtypes.bfloat16 import BF16
from hardware_accelerators.dtypes.float8 import Float8
from hardware_accelerators.nn import model_factory, get_pytorch_device
from hardware_accelerators.rtllib import (
    AcceleratorConfig,
    Accelerator,
    lmul_fast,
    float_multiplier,
)
from hardware_accelerators.simulation import CompiledSimulator, AcceleratorSimulator


# ------------ CONSTANTS ------------ #

# Load the trained model
model_path = "models/mlp_mnist.pth"
model = model_factory()
model.load_state_dict(
    torch.load(model_path, map_location=get_pytorch_device(), weights_only=True)
)
model.eval()

classes = [
    "zero",
    "one",
    "two",
    "three",
    "four",
    "five",
    "six",
    "seven",
    "eight",
    "nine",
]
labels_value = {label: 0.0 for label in classes}

accelerator_dtypes = ["float8", "bfloat16"]
# accelerator_dtypes = ["float8", "float16", "bfloat16", "float32"]

dtype_map = {"float8": Float8, "bfloat16": BF16}

default_config = {
    "activations_dtype": "bfloat16",
    "weights_dtype": "bfloat16",
    "size": 4,
    "multiplication": "IEEE 754",
}

mult_map = {
    "IEEE 754": float_multiplier,
    "l-mul": lmul_fast,
}


# ------------ Event Listener Functions ------------ #


def image_to_tensor(sketchpad: EditorValue):
    image = sketchpad["composite"]
    image = image.resize((28, 28), Image.Resampling.LANCZOS)  # type: ignore
    img_array = np.transpose(np.array(image), (2, 0, 1))[-1]

    # Preprocessing: convert image to tensor and normalize
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )
    tensor_image = transform(img_array).unsqueeze(0)  # Add batch dimension
    return tensor_image


def torch_predict(sketchpad: EditorValue):
    tensor_image = image_to_tensor(sketchpad)
    with torch.no_grad():
        logits = model(tensor_image)
        probabilities = torch.softmax(logits, dim=1).squeeze(0)
    result = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
    return result


def update_accelerator_config(
    activations_dtype: str, weights_dtype: str, size: int, multiplication: str
) -> AcceleratorConfig:

    # Triggered by run simulation button
    print("update_accelerator_config fn called")
    print(activations_dtype, weights_dtype, size, multiplication)

    return AcceleratorConfig(
        num_weight_tiles=4,
        weight_type=dtype_map[weights_dtype],
        data_type=dtype_map[activations_dtype],
        array_size=size,
        pe_multiplier=mult_map[multiplication],
        pe_adder=float_adder,
        accum_adder=float_adder,
        accum_addr_width=8,
        accum_type=dtype_map[activations_dtype],
        pipeline=False,
    )


def simulator_predict(sketchpad: EditorValue, config: AcceleratorConfig):
    # if config == DEFAULT_ACCELERATOR_CONFIG:
    #     sim = ACCELERATOR_SIM
    # else:
    #     sim = AcceleratorSimulator(config=config)

    sim = CompiledSimulator(config=config)
    image = image_to_tensor(sketchpad).detach().numpy().flatten()
    probabilities = sim.run_mlp(model, image)
    result = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
    return result


def sim_predict_progress(
    sketchpad: EditorValue,
    config: AcceleratorConfig,
    gr_progress=gr.Progress(track_tqdm=True),
):
    pyrtl.reset_working_block()
    simulator = CompiledSimulator(config=config)
    chunk_size = config.array_size

    x = image_to_tensor(sketchpad).detach().numpy().flatten()
    probabilities = simulator.run_mlp(model, x)
    return {cls: float(prob) for cls, prob in zip(classes, probabilities)}

    weights_1 = model.fc1.weight.numpy(force=True)
    bias_1 = model.fc1.bias.numpy(force=True)
    weights_2 = model.fc2.weight.numpy(force=True)
    bias_2 = model.fc2.bias.numpy(force=True)

    # Add bias to first layer weights and 1 to activations
    W_aug, x_aug = bias_trick(weights_1, bias_1, x)

    total_tiles = count_total_gemv_tiles([(784, 128), (128, 10)], chunk_size)
    progress = tqdm.tqdm(total=total_tiles)

    tile_generator = generate_gemv_tiles(x_aug, W_aug, chunk_size)

    for tile in tile_generator:
        simulator.load_weights(weights=tile.matrix.T, tile_addr=0)
        simulator.execute_instruction(
            load_new_weights=True,
            weight_tile_addr=0,
            data_vec=tile.vector,
            accum_addr=tile.index,
            accum_mode=not tile.first,
            activation_func="relu",
            activation_enable=tile.last,
            flush_pipeline=True,
        )
        progress.update()

    simulator.execute_instruction(nop=True)
    simulator.execute_instruction(nop=True)

    sim_fc1 = np.array(simulator.output_trace)

    # simulator.reset_output_trace()
    simulator.output_trace = []

    W2_aug, fc1_aug = bias_trick(weights_2, bias_2, sim_fc1.flatten())

    fc2_tile_generator = generate_gemv_tiles(fc1_aug, W2_aug, chunk_size)

    for tile in fc2_tile_generator:
        simulator.load_weights(weights=tile.matrix.T, tile_addr=0)
        simulator.execute_instruction(
            load_new_weights=True,
            weight_tile_addr=0,
            data_vec=tile.vector,
            accum_addr=tile.index,
            accum_mode=not tile.first,
            activation_enable=tile.last,
            flush_pipeline=True,
        )
        progress.update()

    simulator.execute_instruction(nop=True)
    simulator.execute_instruction(nop=True)

    sim_fc2 = np.array(simulator.output_trace).flatten()
    probabilities = softmax(sim_fc2)
    result = {cls: float(prob) for cls, prob in zip(classes, probabilities)}
    return result


# ------------ Blocks UI Layout ------------ #

with gr.Blocks(fill_height=False) as demo:

    accelerator_config = gr.State()

    gr.Markdown("## Draw a digit to see the model's prediction")
    with gr.Row(equal_height=True):
        with gr.Column():
            sketchpad = gr.Sketchpad(
                # label="Draw a digit",
                type="pil",  # Changed to PIL
                transforms=(),
                layers=False,
                canvas_size=(400, 400),
            )

            with gr.Row():
                predict_btn = gr.Button("Run Hardware Simulation", variant="primary")

            # with gr.Accordion("Accelerator Configuration", open=True):
            with gr.Group():
                weight_dtype_component = gr.Radio(
                    label="Weights d-type",
                    choices=accelerator_dtypes,
                    value=default_config["weights_dtype"],
                    interactive=True,
                )
                activation_dtype_component = gr.Radio(
                    label="Activations d-type",
                    choices=accelerator_dtypes,
                    value=default_config["activations_dtype"],
                    interactive=True,
                )
                systolic_array_size_component = gr.Slider(
                    label="Systolic Array Size",
                    info="Large values will significantly slow down the simulation",
                    minimum=2,
                    maximum=16,
                    step=1,
                    value=default_config["size"],
                    interactive=True,
                )
                multiply_component = gr.Radio(
                    label="Multiplication Type",
                    choices=["IEEE 754", "l-mul"],
                    value=default_config["multiplication"],
                    interactive=True,
                )

        with gr.Column():
            pytorch_output = gr.Label(
                label="Pytorch Ground Truth Predictions", value=labels_value
            )

            sim_output = gr.Label(
                label="Hardware Simulator Predictions", value=labels_value
            )

    # ------------ Event Listeners ------------ #

    sketchpad.input(
        fn=torch_predict,
        inputs=sketchpad,
        outputs=pytorch_output,
    )

    # TODO: implement simulator_predict
    predict_btn.click(
        fn=update_accelerator_config,
        inputs=[
            activation_dtype_component,
            weight_dtype_component,
            systolic_array_size_component,
            multiply_component,
        ],
        outputs=accelerator_config,
    ).then(
        fn=sim_predict_progress,
        inputs=[sketchpad, accelerator_config],
        outputs=sim_output,
    )

    # gr.on(
    #     fn=update_accelerator_config,
    #     inputs=[
    #         activation_dtype_component,
    #         weight_dtype_component,
    #         systolic_array_size_component,
    #         multiply_component,
    #     ],
    #     outputs=accelerator_config,
    # )

    # ------------

if __name__ == "__main__":
    demo.queue()
    demo.launch(share=False)
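A note on the `bias_trick` call used in `sim_predict_progress` above: it folds the bias vector into the weight matrix so the systolic array only ever performs plain matrix-vector products. A self-contained sketch of the identity it relies on (the real helper lives in `hardware_accelerators.simulation.matrix_utils`; this standalone version is illustrative only):

import numpy as np

def bias_trick_sketch(W, b, x):
    # Append b as an extra column of W and a constant 1 to x,
    # so that W_aug @ x_aug == W @ x + b.
    W_aug = np.hstack([W, b.reshape(-1, 1)])
    x_aug = np.append(x, 1.0)
    return W_aug, x_aug

W, b, x = np.random.randn(128, 784), np.random.randn(128), np.random.randn(784)
W_aug, x_aug = bias_trick_sketch(W, b, x)
assert np.allclose(W_aug @ x_aug, W @ x + b)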
demo/readme.md
DELETED
@@ -1,9 +0,0 @@
# Interactive Demo

This directory contains a demo where you can test out and compare the performance of the hardware accelerator with a software implementation of the same model. The demo is built using [Gradio](https://gradio.app/), a Python library for creating interactive web applications.

## Running the Demo

Install the project requirements, then from the repo root simply run `python demo/app.py` to start the demo.

We will also provide a ready to go Docker image soon!
demo/requirements.txt
DELETED
@@ -1,17 +0,0 @@
jupyter==1.1.1
ipykernel==6.29.5
tqdm==4.67.0
numpy==2.2.1
ipython==8.12.3
isort==5.13.2
numpy==2.2.1
pandas==2.2.3
pyrtl==0.11.2
matplotlib==3.10.0
pytest==8.3.4
torch==2.4.1
torchvision==0.19.1
onnx==1.17.0
netron==8.1.3
gradio==5.16.0
black[jupyter]==24.10.0
hardware_accelerators/__init__.py
CHANGED
@@ -1,4 +1,8 @@
-from
+from dotenv import load_dotenv
+
+load_dotenv()
+
+from .dtypes import BF16, Float8, Float16, Float32
 from .rtllib import (
     FloatAdderPipelined,
     FloatMultiplierPipelined,
@@ -8,10 +12,21 @@ from .rtllib import (
     lmul_fast,
     lmul_simple,
 )
+from .simulation import (
+    get_sim_cache_dir,
+    set_sim_cache_dir,
+    CompiledAcceleratorSimulator,
+)
+
 
 __all__ = [
+    "get_sim_cache_dir",
+    "set_sim_cache_dir",
+    "CompiledAcceleratorSimulator",
     "Float8",
     "BF16",
+    "Float16",
+    "Float32",
     "float_adder",
     "FloatAdderPipelined",
     "float_multiplier",
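With this change, importing the package now loads a `.env` file as a side effect and re-exports the compiled-simulator cache helpers at the top level. A hypothetical usage sketch (the cache path is an arbitrary example, and `set_sim_cache_dir` taking a single directory-path argument is an assumption):

from hardware_accelerators import get_sim_cache_dir, set_sim_cache_dir

set_sim_cache_dir("/tmp/hwa_cache")  # assumed signature: one directory path
print(get_sim_cache_dir())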
hardware_accelerators/analysis/__init__.py
ADDED
File without changes
hardware_accelerators/analysis/config.py
ADDED
@@ -0,0 +1,25 @@
from ..dtypes import *
from ..rtllib.lmul import lmul_fast, lmul_simple
from ..rtllib.multipliers import float_multiplier


NN_TEST_BATCH_SIZE = 64

NN_TEST_SYSTOLIC_ARRAY_SIZE = 8

NN_TEST_ACCUM_ADDR_WIDTH = 12

NN_TEST_MUL_FNS = [
    float_multiplier,
    lmul_simple,
    # lmul_fast,
]

NN_TEST_WA_DTYPES = [
    # (Float8, Float8),
    (Float8, BF16),
    (Float8, Float32),
    (BF16, BF16),
    # (BF16, Float32),
    # (Float32, Float32),
]
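Nothing in this file consumes these constants directly; they read as the axes of a benchmark sweep, presumably taken as a cartesian product elsewhere in the analysis code. A hypothetical sketch of such a consumer (the loop body is invented for illustration):

from itertools import product

from hardware_accelerators.analysis.config import (
    NN_TEST_MUL_FNS,
    NN_TEST_WA_DTYPES,
    NN_TEST_SYSTOLIC_ARRAY_SIZE,
)

# One benchmark run per (multiplier function, weight dtype, activation dtype) combination.
for mul_fn, (w_dtype, a_dtype) in product(NN_TEST_MUL_FNS, NN_TEST_WA_DTYPES):
    print(mul_fn.__name__, w_dtype.__name__, a_dtype.__name__, NN_TEST_SYSTOLIC_ARRAY_SIZE)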
hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/config.mk
ADDED
@@ -0,0 +1,10 @@
export DESIGN_NAME = lmul_pipelined_fast
export PLATFORM = nangate45
export VERILOG_FILES = $(DESIGN_DIR)/src/lmul_pipelined_fast.v
export SDC_FILE = $(DESIGN_DIR)/constraint.sdc

# These values must be multiples of placement site
export DIE_AREA = 0 0 100 100
export CORE_AREA = 10 10 90 90

export CLOCK_PERIOD = 1.0
hardware_accelerators/analysis/flow/designs/sky130hd/mydesign/constraint.sdc
ADDED
@@ -0,0 +1 @@
create_clock -name clk -period 1.0 [get_ports {clk}]
hardware_accelerators/analysis/generate.py
ADDED
@@ -0,0 +1,958 @@
from pathlib import Path
from h11 import Data
from pandas import DataFrame
import pyrtl
from itertools import product
from pyrtl import *
from dataclasses import dataclass
from typing import Callable, Type, Literal, Optional

from .verilog_export import export_to_verilog
from ..dtypes import *
from ..rtllib import *
from ..rtllib.processing_element import ProcessingElement
from ..rtllib.adders import *
from ..rtllib.multipliers import *
from ..rtllib.lmul import *
from ..rtllib.utils.common import *
from ..simulation.utils import *


def create_inputs(**named_bitwidths):
    """
    Create PyRTL Input wires with specified bitwidths.

    Args:
        **named_bitwidths: Named bitwidths where the key is used as the wire name

    Returns:
        Generator of PyRTL Input wires

    Note:
        You must use all keyword arguments
    """

    # If using keyword arguments
    for name, bitwidth in named_bitwidths.items():
        yield pyrtl.Input(bitwidth, name=name)  # type: ignore


def create_outputs(*args, **named_wires):
    """
    Create PyRTL Output wires connected to the input wires.

    Args:
        *args: Variable number of wires to connect to unnamed outputs
        **named_wires: Named wires where the key is used as the output wire name

    Note:
        You must use either all positional arguments or all keyword arguments, not a mix.
    """
    if args and named_wires:
        raise ValueError(
            "Please use either all positional arguments or all keyword arguments, not a mix."
        )

    # If using positional arguments
    for wire in args:
        out = pyrtl.Output(len(wire), name=wire.name.replace("tmp", "out"))  # type: ignore
        out <<= wire

    # If using keyword arguments
    for name, wire in named_wires.items():
        out = pyrtl.Output(len(wire), name=name)  # type: ignore
        out <<= wire


@dataclass
class RTLAnalysis:
    """Results of RTL analysis."""

    max_delay: float
    max_freq: float
    logic_area: float
    mem_area: float
    name: Optional[str] = None

    def __repr__(self):
        if self.name is None:
            return (
                f"RTLAnalysisResults("
                f"max_delay={self.max_delay:.2f} ps, "
                f"max_freq={self.max_freq:.2f} MHz, "
                f"logic_area={self.logic_area:.2f}um², "
                f"mem_area={self.mem_area:.2f}um²)"
            )
        else:
            return (
                f"RTLAnalysisResults for {self.name}:\n\t"
                f"max_delay={self.max_delay:.2f} ps\n\t"
                f"max_freq={self.max_freq:.2f} MHz\n\t"
                f"logic_area={self.logic_area:.2f}um²\n\t"
                f"mem_area={self.mem_area:.2f}um²"
            )


def analyze(
    block: Block | None = None, synth: bool = True, opt: bool = True, name=None
):
    if block is not None:
        pyrtl.set_working_block(block)

    if synth:
        pyrtl.synthesize()
    if opt:
        pyrtl.optimize()

    timing = pyrtl.TimingAnalysis()
    max_delay = timing.max_length()
    max_freq = timing.max_freq()
    logic_area, mem_area = pyrtl.area_estimation()

    return RTLAnalysis(
        name=name,
        max_delay=max_delay,
        max_freq=max_freq,
        logic_area=logic_area * 1e6,
        mem_area=mem_area * 1e6,
    )

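For orientation: `analyze` synthesizes (and optionally optimizes) the given block in place, then packages PyRTL's timing and area estimates into an `RTLAnalysis` record. A minimal sketch of exercising it on a toy design built with the helpers above (the 8-bit adder is an arbitrary example):

block = pyrtl.Block()
with set_working_block(block):
    a, b = create_inputs(a=8, b=8)   # two 8-bit inputs
    create_outputs(sum_out=a + b)    # 9-bit sum wired to a named output

print(analyze(block, name="toy_adder8"))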

def create_adder_blocks(dtype: Type[BaseFloat]) -> dict[str, Block]:
    bits = dtype.bitwidth()
    e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()

    combinational_block = pyrtl.Block()
    combinational_fast_block = pyrtl.Block()
    adder_pipelined_block = pyrtl.Block()
    adder_pipelined_fast_block = pyrtl.Block()
    stage_2_block = pyrtl.Block()
    stage_2_fast_block = pyrtl.Block()
    stage_3_block = pyrtl.Block()
    stage_4_block = pyrtl.Block()
    stage_4_fast_block = pyrtl.Block()
    stage_5_block = pyrtl.Block()

    # Combinational design
    with set_working_block(combinational_block):
        create_outputs(
            *float_adder(
                *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=False
            )
        )

    with set_working_block(combinational_fast_block):
        create_outputs(
            *float_adder(
                *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=True
            )
        )

    # Complete pipelined design
    with set_working_block(adder_pipelined_block):
        create_outputs(
            float_adder_pipelined(
                *create_inputs(float_a=bits, float_b=bits),
                dtype=dtype,
                fast=False,
            )
        )

    with set_working_block(adder_pipelined_fast_block):
        create_outputs(
            float_adder_pipelined(
                *create_inputs(float_a=bits, float_b=bits),
                dtype=dtype,
                fast=True,
            )
        )

    # Stages 1 & 2
    with set_working_block(stage_2_block):
        float_components = extract_float_components(
            *create_inputs(float_a=bits, float_b=bits),
            e_bits=e_bits,
            m_bits=m_bits,
        )
        stage_2_outputs = adder_stage_2(
            *float_components,
            e_bits,
            m_bits,
            fast=False,
        )
        create_outputs(*stage_2_outputs)

    with set_working_block(stage_2_fast_block):
        float_components = extract_float_components(
            *create_inputs(float_a=bits, float_b=bits),
            e_bits=e_bits,
            m_bits=m_bits,
        )
        stage_2_outputs = adder_stage_2(
            *float_components,
            e_bits,
            m_bits,
            fast=True,
        )
        create_outputs(*stage_2_outputs)

    # Stage 3
    with set_working_block(stage_3_block):
        # Perform alignment and generate SGR bits
        stage_3_outputs = adder_stage_3(
            *create_inputs(mant_smaller=m_bits + 1, shift_amount=e_bits),
            e_bits=e_bits,
            m_bits=m_bits,
        )
        create_outputs(*stage_3_outputs)

    # Stage 4
    with set_working_block(stage_4_block):
        # Perform mantissa addition and leading zero detection
        stage_4_outputs = adder_stage_4(
            *create_inputs(mant_aligned=m_bits + 1, mant_unchanged=m_bits + 1, s_xor=1),
            m_bits=m_bits,
            fast=False,
        )
        create_outputs(*stage_4_outputs)

    with set_working_block(stage_4_fast_block):
        # Perform mantissa addition and leading zero detection
        stage_4_outputs = adder_stage_4(
            *create_inputs(mant_aligned=m_bits + 1, mant_unchanged=m_bits + 1, s_xor=1),
            m_bits=m_bits,
            fast=True,
        )
        create_outputs(*stage_4_outputs)

    # Stage 5
    with set_working_block(stage_5_block):
        # Perform normalization, rounding, and final assembly
        stage_5_outputs = adder_stage_5(
            *create_inputs(
                abs_mantissa=m_bits + 2,
                sticky_bit=1,
                guard_bit=1,
                round_bit=1,
                lzc=4,
                exp_larger=e_bits,
                sign_a=1,
                sign_b=1,
                exp_diff=e_bits + 1,
                is_neg=1,
            ),
            e_bits=e_bits,
            m_bits=m_bits,
        )
        create_outputs(*stage_5_outputs)

    # Return all the generated blocks for analysis
    return {
        "adder_combinational": combinational_block,
        "adder_combinational_fast": combinational_fast_block,
        "adder_pipelined": adder_pipelined_block,
        "adder_pipelined_fast": adder_pipelined_fast_block,
        "adder_stage_2": stage_2_block,
        "adder_stage_2_fast": stage_2_fast_block,
        "adder_stage_3": stage_3_block,
        "adder_stage_4": stage_4_block,
        "adder_stage_4_fast": stage_4_fast_block,
        "adder_stage_5": stage_5_block,
    }

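Each `create_*_blocks` factory returns a name-to-`pyrtl.Block` mapping, so sweeping one dtype is a short loop. A sketch that reuses `analyze` and the module's `DataFrame` import, assuming `BF16` is re-exported by the wildcard dtypes import (as the commented-out main block further down also assumes; the output filename is arbitrary):

rows = []
for block_name, blk in create_adder_blocks(BF16).items():
    rows.append(analyze(blk, name=block_name).__dict__)

DataFrame(rows).to_csv("adder_analysis_bf16.csv", index=False)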
def create_multiplier_blocks(dtype: Type[BaseFloat], fast: bool) -> dict[str, Block]:
    bits = dtype.bitwidth()
    e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()

    combinational_block = pyrtl.Block()
    multiplier_block = pyrtl.Block()
    stage_2_block = pyrtl.Block()
    stage_3_block = pyrtl.Block()
    stage_4_block = pyrtl.Block()

    # Combinational design
    with set_working_block(combinational_block):
        create_outputs(
            float_multiplier(
                *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=fast
            )
        )

    # Complete pipelined design
    with set_working_block(multiplier_block):
        multiplier = FloatMultiplierPipelined(
            *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=fast
        )
        create_outputs(multiplier._result)

    # Stage 1 & 2: Extract components and calculate sign, exponent sum, mantissa product
    with set_working_block(stage_2_block):
        float_components = extract_float_components(
            *create_inputs(float_a=bits, float_b=bits),
            e_bits=e_bits,
            m_bits=m_bits,
        )
        stage_2_outputs = multiplier_stage_2(
            *float_components,
            m_bits,
            fast,
        )
        create_outputs(*stage_2_outputs)

    # Stage 3: Leading zero detection and exponent adjustment
    with set_working_block(stage_3_block):
        stage_3_outputs = multiplier_stage_3(
            *create_inputs(exp_sum=e_bits + 1, mant_product=2 * m_bits + 2),
            e_bits=e_bits,
            m_bits=m_bits,
            fast=fast,
        )
        create_outputs(*stage_3_outputs)

    # Stage 4: Normalization, rounding, and final assembly
    with set_working_block(stage_4_block):
        stage_4_outputs = multiplier_stage_4(
            *create_inputs(
                unbiased_exp=e_bits,
                leading_zeros=e_bits,
                mantissa_product=2 * m_bits + 2,
            ),
            m_bits=m_bits,
            e_bits=e_bits,
            fast=fast,
        )
        create_outputs(*stage_4_outputs)

    # Return all the generated blocks for analysis
    faststr = "_fast" if fast else ""
    return {
        f"multiplier_combinational{faststr}": combinational_block,
        f"multiplier{faststr}": multiplier_block,
        f"multiplier_stage_2{faststr}": stage_2_block,
        f"multiplier_stage_3{faststr}": stage_3_block,
        f"multiplier_stage_4{faststr}": stage_4_block,
    }


def create_lmul_blocks(dtype: Type[BaseFloat]) -> dict[str, Block]:
    bits = dtype.bitwidth()

    combinational_block = pyrtl.Block()
    combinational_fast_block = pyrtl.Block()
    pipelined_block = pyrtl.Block()
    pipelined_fast_block = pyrtl.Block()

    # Combinational design (simple)
    with set_working_block(combinational_block):
        create_outputs(
            lmul_simple(*create_inputs(float_a=bits, float_b=bits), dtype=dtype)
        )

    # Combinational design (fast)
    with set_working_block(combinational_fast_block):
        create_outputs(
            lmul_fast(*create_inputs(float_a=bits, float_b=bits), dtype=dtype)
        )

    # Pipelined design (simple)
    with set_working_block(pipelined_block):
        mult = LmulPipelined(
            *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=False
        )
        create_outputs(mult.output_reg)

    # Pipelined design (fast)
    with set_working_block(pipelined_fast_block):
        mult = LmulPipelined(
            *create_inputs(float_a=bits, float_b=bits), dtype=dtype, fast=True
        )
        create_outputs(mult.output_reg)

    # Return all the generated blocks for analysis
    return {
        "lmul_combinational_simple": combinational_block,
        "lmul_combinational_fast": combinational_fast_block,
        "lmul_pipelined_simple": pipelined_block,
        "lmul_pipelined_fast": pipelined_fast_block,
    }


def connect_pe_io(pe: ProcessingElement):
    # Connect the inputs and outputs of the processing element
    w_bits, a_bits = pe.weight_type.bitwidth(), pe.data_type.bitwidth()
    w_in, d_in, acc_in = create_inputs(
        weight_in=w_bits, data_in=a_bits, accum_in=a_bits
    )
    pe.connect_weight(w_in)
    pe.connect_data(d_in)
    pe.connect_accum(acc_in)
    pe.connect_control_signals(
        *create_inputs(weight_en=1, data_en=1, mul_en=1, adder_en=1)
    )
    create_outputs(*pe.outputs.__dict__.values())


def create_pe_blocks(
    dtypes: tuple[Type[BaseFloat], Type[BaseFloat]],
) -> dict[str, Block]:
    """Create a processing element for each pair of dtypes."""

    weight_dtype, act_dtype = dtypes

    # Defining blocks to encapsulate hardware

    combinational_block = Block()
    simple_pipeline_block = Block()
    simple_pipeline_fast_block = Block()
    full_pipeline_block = Block()
    full_pipeline_fast_block = Block()

    combinational_lmul_block = Block()
    simple_pipeline_lmul_block = Block()
    simple_pipeline_fast_lmul_block = Block()
    full_pipeline_lmul_block = Block()
    full_pipeline_fast_lmul_block = Block()

    # Standard IEEE multiplier versions

    with set_working_block(combinational_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=float_multiplier,
            adder=float_adder,
            pipeline_mult=False,
        )
        connect_pe_io(pe)

    with set_working_block(simple_pipeline_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=float_multiplier,
            adder=float_adder,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    with set_working_block(simple_pipeline_fast_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=float_multiplier_fast_unstable,
            adder=float_adder_fast_unstable,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    with set_working_block(full_pipeline_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=float_multiplier_pipelined,
            adder=float_adder_pipelined,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    with set_working_block(full_pipeline_fast_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=float_multiplier_pipelined_fast_unstable,
            adder=float_adder_pipelined_fast_unstable,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    # L-mul versions

    with set_working_block(combinational_lmul_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=lmul_simple,
            adder=float_adder,
            pipeline_mult=False,
        )
        connect_pe_io(pe)

    with set_working_block(simple_pipeline_lmul_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=lmul_simple,
            adder=float_adder,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    with set_working_block(simple_pipeline_fast_lmul_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=lmul_fast,
            adder=float_adder_fast_unstable,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    with set_working_block(full_pipeline_lmul_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=lmul_pipelined,
            adder=float_adder_pipelined,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    with set_working_block(full_pipeline_fast_lmul_block):
        pe = ProcessingElement(
            data_type=act_dtype,
            weight_type=weight_dtype,
            accum_type=act_dtype,
            multiplier=lmul_pipelined_fast,
            adder=float_adder_pipelined_fast_unstable,
            pipeline_mult=True,
        )
        connect_pe_io(pe)

    return {
        "pe_combinational": combinational_block,
        "pe_standard": simple_pipeline_block,
        "pe_fast": simple_pipeline_fast_block,
        "pe_pipelined": full_pipeline_block,
        "pe_fast_pipelined": full_pipeline_fast_block,
        "pe_combinational_lmul": combinational_lmul_block,
        "pe_standard_lmul": simple_pipeline_lmul_block,
        "pe_fast_lmul": simple_pipeline_fast_lmul_block,
        "pe_pipelined_lmul": full_pipeline_lmul_block,
        "pe_fast_pipelined_lmul": full_pipeline_fast_lmul_block,
    }

def create_accelerator_blocks(
    dtypes: tuple[Type[BaseFloat], Type[BaseFloat]],
    array_size: int = 4,
    addr_bits: int = 12,
) -> dict[str, Block]:
    """
    Create accelerator blocks for all valid configurations based on the given inputs.

    Args:
        dtypes: Tuple of (weight_type, activation_type) data types
        array_size: Size of the systolic array (N x N)
        addr_bits: Bit width for accumulator address (uses default if None)

    Returns:
        Dictionary mapping configuration names to PyRTL blocks
    """
    weight_type, activation_type = dtypes

    # Define all valid configurations to test
    pipeline_options = [None, "low", "high"]
    lmul_options = [False, True]
    fast_options = [False, True]

    # Create configs and blocks
    blocks = {}
    for pipeline, lmul, fast in product(pipeline_options, lmul_options, fast_options):
        if pipeline is None and fast is True:
            continue

        # Create the configuration
        config = AcceleratorAnalysisConfig(
            array_size=array_size,
            activation_type=activation_type,
            weight_type=weight_type,
            lmul=lmul,
            accum_addr_width=addr_bits,
            pipeline_level=pipeline,
            use_fast_internals=fast,
        )

        block = pyrtl.Block()
        with set_working_block(block):
            AcceleratorTopLevel(config)

        blocks[config.name] = block

    return blocks

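A quick sanity check on the enumeration above: three pipeline levels, two lmul settings, and two fast settings give 12 combinations, and the guard drops the two unpipelined-but-fast ones, so each dtype pair yields 10 accelerator blocks. The same count, computed directly:

combos = [
    (p, l, f)
    for p, l, f in product([None, "low", "high"], [False, True], [False, True])
    if not (p is None and f)
]
print(len(combos))  # 10 valid configurations per (weight, activation) pair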

################################################################

# if __name__ == "__main__":

#     OUTPUT_DIR = Path("verilog")
#     POSTSYNTH_DIR = OUTPUT_DIR / "pyrtl_synth"

#     EXPORT_PRE_SYNTH = False
#     EXPORT_POST_SYNTH = True
#     RUN_ANALYSIS = True
#     ANALYSIS_RESULT_DIR = Path("results")

#     array_size = 8
#     addr_bits = 12

#     dtype_list = [Float8, BF16, Float32]

#     dtype_names = {Float8: "fp8", BF16: "bf16", Float32: "fp32"}

#     weight_act_dtypes = [
#         (Float8, Float8),
#         (Float8, BF16),
#         (Float8, Float32),
#         (BF16, BF16),
#         (BF16, Float32),
#         (Float32, Float32),
#     ]

#     # Hardware building blocks
#     basic_component_analysis = []

#     for dtype in dtype_list:
#         block_dicts = [
#             ("adder", create_adder_blocks(dtype)),
#             ("multiplier", create_multiplier_blocks(dtype, fast=False)),
#             ("multiplier", create_multiplier_blocks(dtype, fast=True)),
#             ("lmul", create_lmul_blocks(dtype)),
#         ]
#         for component_name, block_dict in block_dicts:
#             for name, block in block_dict.items():
#                 output_path = Path(component_name, dtype_names[dtype], f"{name}.v")
#                 if EXPORT_PRE_SYNTH:
#                     export_to_verilog(block, OUTPUT_DIR / output_path)
#                 if RUN_ANALYSIS:
#                     analysis_result = analyze(block, name=name)
#                     analysis_result.dtype = dtype_names[dtype]
#                     analysis_result.component = component_name
#                     basic_component_analysis.append(analysis_result.__dict__)
#                 if EXPORT_POST_SYNTH:
#                     export_to_verilog(block, POSTSYNTH_DIR / output_path)

#     # More complex hardware
#     pe_analysis = []
#     accelerator_analysis = []

#     for weight_dtype, act_dtype in weight_act_dtypes:
#         folder_name = f"w{weight_dtype.bitwidth()}a{act_dtype.bitwidth()}"

#         pe_blocks = create_pe_blocks((weight_dtype, act_dtype))
#         for name, block in pe_blocks.items():
#             pe_output_path = Path("pe", folder_name, f"{name}.v")
#             if EXPORT_PRE_SYNTH:
#                 export_to_verilog(block, OUTPUT_DIR / pe_output_path)
#             if RUN_ANALYSIS:
#                 analysis_result = analyze(block, name=name)
#                 analysis_result.weights = dtype_names[weight_dtype]
#                 analysis_result.activations = dtype_names[act_dtype]
#                 analysis_result.component = "pe"
#                 pe_analysis.append(analysis_result.__dict__)
#             if EXPORT_POST_SYNTH:
#                 export_to_verilog(block, POSTSYNTH_DIR / pe_output_path)

#         accelerator_blocks = create_accelerator_blocks(
#             (weight_dtype, act_dtype), array_size, addr_bits
#         )
#         for name, block in accelerator_blocks.items():
#             accelerator_output_path = Path("accelerator", folder_name, f"{name}.v")
#             if EXPORT_PRE_SYNTH:
#                 export_to_verilog(block, OUTPUT_DIR / accelerator_output_path)
#             if RUN_ANALYSIS:
#                 analysis_result = analyze(block, name=name)
#                 analysis_result.weights = dtype_names[weight_dtype]
#                 analysis_result.activations = dtype_names[act_dtype]
#                 analysis_result.component = "accelerator"
#                 accelerator_analysis.append(analysis_result.__dict__)
#             if EXPORT_POST_SYNTH:
#                 export_to_verilog(block, POSTSYNTH_DIR / accelerator_output_path)

#     if RUN_ANALYSIS:
#         DataFrame(basic_component_analysis).to_csv(
#             ANALYSIS_RESULT_DIR / "component_analysis.csv", index=False
#         )
#         DataFrame(pe_analysis).to_csv(
#             ANALYSIS_RESULT_DIR / "pe_analysis.csv", index=False
#         )
#         DataFrame(accelerator_analysis).to_csv(
#             ANALYSIS_RESULT_DIR / "accelerator_analysis.csv", index=False
#         )


import multiprocessing as mp
from pathlib import Path
import os
import csv
import time
from functools import partial
from pandas import DataFrame
import json
import traceback


def process_block(
    block,
    name,
    output_dir,
    postsynth_dir,
    export_pre_synth,
    export_post_synth,
    run_analysis,
    analysis_result_dir,
    component_name=None,
    dtype=None,
    weight_dtype=None,
    act_dtype=None,
    dtype_names=None,
    output_path=None,
):
    """Process a single block with optional export and analysis"""
    result = None
    try:
        if export_pre_synth and output_path:
            os.makedirs((output_dir / output_path).parent, exist_ok=True)
            export_to_verilog(block, output_dir / output_path)

        if run_analysis:
            analysis_result = analyze(block, name=name)

            # Set appropriate attributes based on the component type
            if component_name and dtype:
                analysis_result.dtype = dtype_names[dtype]
                analysis_result.component = component_name
                result_type = "component"
            elif weight_dtype and act_dtype:
                analysis_result.weights = dtype_names[weight_dtype]
                analysis_result.activations = dtype_names[act_dtype]
                analysis_result.component = component_name
                result_type = "pe" if component_name == "pe" else "accelerator"

            result = (result_type, analysis_result.__dict__)

        if export_post_synth and output_path:
            os.makedirs((postsynth_dir / output_path).parent, exist_ok=True)
            export_to_verilog(block, postsynth_dir / output_path)

        return result
    except Exception as e:
        error_msg = f"Error processing {name}: {str(e)}\n{traceback.format_exc()}"
        print(error_msg)
        with open(analysis_result_dir / "errors.log", "a") as f:
            f.write(f"{error_msg}\n{'='*80}\n")
        return None


def save_result(result, analysis_result_dir, result_files):
    """Save a single result to the appropriate CSV file"""
    if not result:
        return

    result_type, data = result

    # Get the appropriate file path and headers
    if result_type == "component":
        file_path = analysis_result_dir / "component_analysis.csv"
    elif result_type == "pe":
        file_path = analysis_result_dir / "pe_analysis.csv"
    elif result_type == "accelerator":
        file_path = analysis_result_dir / "accelerator_analysis.csv"

    # Create directory if it doesn't exist
    os.makedirs(file_path.parent, exist_ok=True)

    # Check if file exists to determine if we need to write headers
    file_exists = file_path.exists()

    # Get headers from the data
    headers = list(data.keys())

    # Open file in append mode
    with open(file_path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=headers)

        # Write headers if file doesn't exist
        if not file_exists:
            writer.writeheader()

        # Write the data row
        writer.writerow(data)

    # Track that we've written to this file
    result_files.add(file_path)


def result_callback(result, analysis_result_dir, result_files):
    """Callback function for when a process completes"""
    if result:
        save_result(result, analysis_result_dir, result_files)


if __name__ == "__main__":
    # Create a multiprocessing pool with as many processes as CPU cores
    pool = mp.Pool(processes=mp.cpu_count())

    # Set to track which result files we've written to
    result_files = set()

    OUTPUT_DIR = Path("verilog")
    POSTSYNTH_DIR = OUTPUT_DIR / "pyrtl_synth"

    EXPORT_PRE_SYNTH = False
    EXPORT_POST_SYNTH = True
    RUN_ANALYSIS = True
    ANALYSIS_RESULT_DIR = Path("results")

    # Create output directories
    os.makedirs(OUTPUT_DIR, exist_ok=True)
    os.makedirs(POSTSYNTH_DIR, exist_ok=True)
    os.makedirs(ANALYSIS_RESULT_DIR, exist_ok=True)

    array_size = 8
    addr_bits = 12

    dtype_list = [Float8, BF16, Float32]

    dtype_names = {Float8: "fp8", BF16: "bf16", Float32: "fp32"}

    weight_act_dtypes = [
        (Float8, Float8),
        (Float8, BF16),
        (Float8, Float32),
        (BF16, BF16),
        (BF16, Float32),
        (Float32, Float32),
    ]

    # Create a partial function with common arguments
    process_block_partial = partial(
        process_block,
        output_dir=OUTPUT_DIR,
        postsynth_dir=POSTSYNTH_DIR,
        export_pre_synth=EXPORT_PRE_SYNTH,
        export_post_synth=EXPORT_POST_SYNTH,
        run_analysis=RUN_ANALYSIS,
        analysis_result_dir=ANALYSIS_RESULT_DIR,
        dtype_names=dtype_names,
    )

    # Create a callback function with common arguments
    callback = partial(
        result_callback,
        analysis_result_dir=ANALYSIS_RESULT_DIR,
        result_files=result_files,
    )

    # Track all submitted tasks
    tasks = []

    # Hardware building blocks
| 861 |
+
for dtype in dtype_list:
|
| 862 |
+
block_dicts = [
|
| 863 |
+
("adder", create_adder_blocks(dtype)),
|
| 864 |
+
("multiplier", create_multiplier_blocks(dtype, fast=False)),
|
| 865 |
+
("multiplier", create_multiplier_blocks(dtype, fast=True)),
|
| 866 |
+
("lmul", create_lmul_blocks(dtype)),
|
| 867 |
+
]
|
| 868 |
+
|
| 869 |
+
for component_name, block_dict in block_dicts:
|
| 870 |
+
for name, block in block_dict.items():
|
| 871 |
+
output_path = Path(component_name, dtype_names[dtype], f"{name}.v")
|
| 872 |
+
|
| 873 |
+
# Submit task to process pool
|
| 874 |
+
task = pool.apply_async(
|
| 875 |
+
process_block_partial,
|
| 876 |
+
args=(block, name),
|
| 877 |
+
kwds={
|
| 878 |
+
"component_name": component_name,
|
| 879 |
+
"dtype": dtype,
|
| 880 |
+
"output_path": output_path,
|
| 881 |
+
},
|
| 882 |
+
callback=callback,
|
| 883 |
+
)
|
| 884 |
+
tasks.append(task)
|
| 885 |
+
|
| 886 |
+
# More complex hardware
|
| 887 |
+
for weight_dtype, act_dtype in weight_act_dtypes:
|
| 888 |
+
folder_name = f"w{weight_dtype.bitwidth()}a{act_dtype.bitwidth()}"
|
| 889 |
+
|
| 890 |
+
# Process PE blocks
|
| 891 |
+
pe_blocks = create_pe_blocks((weight_dtype, act_dtype))
|
| 892 |
+
for name, block in pe_blocks.items():
|
| 893 |
+
pe_output_path = Path("pe", folder_name, f"{name}.v")
|
| 894 |
+
|
| 895 |
+
task = pool.apply_async(
|
| 896 |
+
process_block_partial,
|
| 897 |
+
args=(block, name),
|
| 898 |
+
kwds={
|
| 899 |
+
"component_name": "pe",
|
| 900 |
+
"weight_dtype": weight_dtype,
|
| 901 |
+
"act_dtype": act_dtype,
|
| 902 |
+
"output_path": pe_output_path,
|
| 903 |
+
},
|
| 904 |
+
callback=callback,
|
| 905 |
+
)
|
| 906 |
+
tasks.append(task)
|
| 907 |
+
|
| 908 |
+
# Process accelerator blocks
|
| 909 |
+
accelerator_blocks = create_accelerator_blocks(
|
| 910 |
+
(weight_dtype, act_dtype), array_size, addr_bits
|
| 911 |
+
)
|
| 912 |
+
for name, block in accelerator_blocks.items():
|
| 913 |
+
accelerator_output_path = Path("accelerator", folder_name, f"{name}.v")
|
| 914 |
+
|
| 915 |
+
task = pool.apply_async(
|
| 916 |
+
process_block_partial,
|
| 917 |
+
args=(block, name),
|
| 918 |
+
kwds={
|
| 919 |
+
"component_name": "accelerator",
|
| 920 |
+
"weight_dtype": weight_dtype,
|
| 921 |
+
"act_dtype": act_dtype,
|
| 922 |
+
"output_path": accelerator_output_path,
|
| 923 |
+
},
|
| 924 |
+
callback=callback,
|
| 925 |
+
)
|
| 926 |
+
tasks.append(task)
|
| 927 |
+
|
| 928 |
+
# Wait for all tasks to complete
|
| 929 |
+
try:
|
| 930 |
+
# Monitor progress
|
| 931 |
+
total_tasks = len(tasks)
|
| 932 |
+
completed = 0
|
| 933 |
+
print(f"Processing {total_tasks} tasks using {mp.cpu_count()} processes")
|
| 934 |
+
|
| 935 |
+
while completed < total_tasks:
|
| 936 |
+
new_completed = sum(1 for task in tasks if task.ready())
|
| 937 |
+
if new_completed > completed:
|
| 938 |
+
completed = new_completed
|
| 939 |
+
print(
|
| 940 |
+
f"Progress: {completed}/{total_tasks} tasks completed ({completed/total_tasks*100:.1f}%)"
|
| 941 |
+
)
|
| 942 |
+
time.sleep(1)
|
| 943 |
+
|
| 944 |
+
# Make sure all tasks are properly completed
|
| 945 |
+
for task in tasks:
|
| 946 |
+
task.get() # This will raise any exceptions that occurred in the task
|
| 947 |
+
|
| 948 |
+
except KeyboardInterrupt:
|
| 949 |
+
print("Process interrupted by user. Partial results have been saved.")
|
| 950 |
+
except Exception as e:
|
| 951 |
+
print(f"An error occurred: {str(e)}")
|
| 952 |
+
print("Partial results have been saved.")
|
| 953 |
+
finally:
|
| 954 |
+
# Close the pool
|
| 955 |
+
pool.close()
|
| 956 |
+
pool.join()
|
| 957 |
+
|
| 958 |
+
print(f"Results saved to: {', '.join(str(f) for f in result_files)}")
|
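The script above fans each PyRTL block out to a worker pool and appends its analysis row to CSV from a completion callback, so partial results survive a crash or a Ctrl-C. A minimal self-contained sketch of the same apply_async-plus-callback pattern; the work function, the row shape, and demo_results.csv are hypothetical stand-ins:

import csv
import multiprocessing as mp
from pathlib import Path


def work(x):
    # Stand-in for process_block: return a result row, or None on failure
    return {"name": f"task{x}", "value": x * x}


def on_done(row, path=Path("demo_results.csv")):
    # Mirrors result_callback/save_result: append each row as it arrives,
    # writing the CSV header only when the file is first created
    if row is None:
        return
    new_file = not path.exists()
    with open(path, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(row.keys()))
        if new_file:
            writer.writeheader()
        writer.writerow(row)


if __name__ == "__main__":
    with mp.Pool(processes=4) as pool:
        tasks = [pool.apply_async(work, (i,), callback=on_done) for i in range(8)]
        for t in tasks:
            t.get()  # re-raises any exception from the worker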
hardware_accelerators/analysis/hardware_stats.py ADDED
@@ -0,0 +1,458 @@
import pandas as pd
import numpy as np
from typing import Literal, Dict, Tuple, List, Any, Optional

# Mapping from UI options to dataframe values
DTYPE_MAP = {"float8": "fp8", "bfloat16": "bf16", "float32": "fp32"}

SPEED_MAP = {"Fast": True, "Efficient": False}

PIPELINE_MAP = {
    "None": "combinational",
    "Low": "combinational",
    "Full": "pipelined",
    "High": "pipelined",
}


def filter_components(
    df: pd.DataFrame, operation: str, dtype: str, is_fast: bool, architecture: str
) -> pd.DataFrame:
    """
    Filter the dataframe to get components matching the specified criteria.

    Args:
        df: DataFrame containing component data
        operation: Type of operation ('multiplier', 'lmul', 'adder')
        dtype: Data type ('fp8', 'bf16', 'fp32')
        is_fast: Whether to use fast components
        architecture: Architecture type ('combinational', 'pipelined')

    Returns:
        Filtered DataFrame
    """
    filtered_df = df[
        (df["operation"] == operation)
        & (df["dtype"] == dtype)
        & (df["is_fast"] == is_fast)
        & (df["architecture"] == architecture)
    ]

    if filtered_df.empty:
        # If no exact match, try without the is_fast constraint
        filtered_df = df[
            (df["operation"] == operation)
            & (df["dtype"] == dtype)
            & (df["architecture"] == architecture)
        ]

    if filtered_df.empty:
        # If still no match, try without architecture constraint
        filtered_df = df[(df["operation"] == operation) & (df["dtype"] == dtype)]

    return filtered_df


def get_pe_components(
    df: pd.DataFrame, mult_type: str, dtype: str, is_fast: bool, architecture: str
) -> Tuple[pd.Series, pd.Series]:
    """
    Get the multiplier/lmul and adder components for a PE.

    Args:
        df: DataFrame containing component data
        mult_type: Type of multiplier ('multiplier' or 'lmul')
        dtype: Data type ('fp8', 'bf16', 'fp32')
        is_fast: Whether to use fast components
        architecture: Architecture type ('combinational', 'pipelined')

    Returns:
        Tuple of (multiplier, adder) Series
    """
    mult_df = filter_components(df, mult_type, dtype, is_fast, architecture)
    adder_df = filter_components(df, "adder", dtype, is_fast, architecture)

    if mult_df.empty or adder_df.empty:
        raise ValueError(
            f"Could not find components for {mult_type}, {dtype}, {is_fast}, {architecture}"
        )

    # Get the first matching component
    multiplier = mult_df.iloc[0]
    adder = adder_df.iloc[0]

    return multiplier, adder


def calculate_pe_metrics(multiplier: pd.Series, adder: pd.Series) -> Dict[str, float]:
    """
    Calculate metrics for a single PE (Processing Element).

    Args:
        multiplier: Series containing multiplier component data
        adder: Series containing adder component data

    Returns:
        Dictionary of PE metrics
    """
    # Area and power are additive
    area = multiplier["area"] + adder["area"]
    power = multiplier["power"] + adder["power"]

    # Critical path depends on architecture
    if (
        multiplier["architecture"] == "combinational"
        and adder["architecture"] == "combinational"
    ):
        delay = max(multiplier["max_arrival_time"], adder["max_arrival_time"])
        pipeline_stages = 1
    else:
        # For pipelined designs, we assume the critical path is the slowest stage
        delay = max(multiplier["max_arrival_time"], adder["max_arrival_time"])
        # Estimate pipeline stages
        if multiplier["architecture"] == "pipelined":
            mult_stages = 4  # Assumption for pipelined multiplier
        else:
            mult_stages = 1

        if adder["architecture"] == "pipelined":
            add_stages = 5  # Assumption for pipelined adder
        else:
            add_stages = 1

        pipeline_stages = mult_stages + add_stages - 1  # -1 because they share a stage

    # Calculate performance metrics
    clock_freq_ghz = 1.0 / delay  # GHz, assuming delay is in ns
    ops_per_cycle = 2  # 1 multiply + 1 add = 2 FLOPs per cycle
    tflops = clock_freq_ghz * ops_per_cycle / 1000  # TFLOPS for a single PE

    # Efficiency metrics
    tflops_per_watt = tflops / power
    tflops_per_mm2 = tflops / (area * 1e-6)  # Convert area to mm²
    power_density = power / (area * 1e-6)  # W/mm²

    return {
        "area_um2": area,
        "area_mm2": area * 1e-6,
        "power_w": power,
        "delay_ns": delay,
        "clock_freq_ghz": clock_freq_ghz,
        "pipeline_stages": pipeline_stages,
        "tflops": tflops,
        "tflops_per_watt": tflops_per_watt,
        "tflops_per_mm2": tflops_per_mm2,
        "power_density": power_density,
        "energy_per_op_pj": (power / (clock_freq_ghz * ops_per_cycle * 1e9))
        * 1e12,  # pJ per operation
    }


def calculate_array_metrics(
    pe_metrics: Dict[str, float], array_size: int, num_cores: int
) -> Dict[str, float]:
    """
    Calculate metrics for a systolic array and the entire accelerator.

    Args:
        pe_metrics: Dictionary of PE metrics
        array_size: Size of the systolic array (NxN)
        num_cores: Number of accelerator cores

    Returns:
        Dictionary of array and accelerator metrics
    """
    # Number of PEs per array and total
    pes_per_array = array_size * array_size
    total_pes = pes_per_array * num_cores

    # Scale metrics for a single array
    array_area_mm2 = pe_metrics["area_mm2"] * pes_per_array
    array_power_w = pe_metrics["power_w"] * pes_per_array

    # Scale metrics for the entire accelerator
    total_area_mm2 = array_area_mm2 * num_cores
    total_power_w = array_power_w * num_cores

    # Performance scales with the number of PEs
    array_tflops = pe_metrics["tflops"] * pes_per_array
    total_tflops = array_tflops * num_cores

    # Efficiency metrics
    total_tflops_per_watt = total_tflops / total_power_w
    total_tflops_per_mm2 = total_tflops / total_area_mm2

    # Latency calculation
    # For an NxN array, data takes 2N-1 cycles to flow through
    # Plus pipeline_stages-1 cycles for the pipeline to fill
    pipeline_latency_cycles = pe_metrics["pipeline_stages"] - 1
    array_latency_cycles = 2 * array_size - 1
    total_latency_cycles = array_latency_cycles + pipeline_latency_cycles

    # Time per cycle based on clock frequency
    cycle_time_ns = 1.0 / pe_metrics["clock_freq_ghz"]

    # Total latency in ns
    total_latency_ns = total_latency_cycles * cycle_time_ns

    # Throughput after pipeline is filled (ops per second)
    throughput_ops_per_second = (
        pe_metrics["clock_freq_ghz"] * 1e9 * pes_per_array * 2
    )  # 2 ops per PE per cycle
    total_throughput_ops_per_second = throughput_ops_per_second * num_cores

    # Energy per matrix multiplication
    # Assuming an NxN matrix multiply requires N³ operations
    ops_per_matmul = array_size**3
    energy_per_matmul_nj = (
        (array_power_w / throughput_ops_per_second) * ops_per_matmul * 1e9
    )  # nJ

    # Inference metrics (assuming a simple MLP with 3 layers)
    # Each layer requires a matrix multiplication
    num_layers = 3
    inference_ops = ops_per_matmul * num_layers
    inference_latency_ns = (inference_ops / throughput_ops_per_second) * 1e9
    inference_energy_uj = (
        (total_power_w / total_throughput_ops_per_second) * inference_ops * 1e6
    )  # uJ

    return {
        "array_size": array_size,
        "num_cores": num_cores,
        "pes_per_array": pes_per_array,
        "total_pes": total_pes,
        "clock_freq_ghz": pe_metrics["clock_freq_ghz"],
        "array_area_mm2": array_area_mm2,
        "total_area_mm2": total_area_mm2,
        "array_power_w": array_power_w,
        "total_power_w": total_power_w,
        "array_tflops": array_tflops,
        "total_tflops": total_tflops,
        "tflops_per_watt": total_tflops_per_watt,
        "tflops_per_mm2": total_tflops_per_mm2,
        "power_density_w_mm2": total_power_w / total_area_mm2,
        "total_latency_cycles": total_latency_cycles,
        "total_latency_ns": total_latency_ns,
        "throughput_gops": total_throughput_ops_per_second / 1e9,  # GOPS
        "energy_per_matmul_nj": energy_per_matmul_nj,
        "inference_latency_ns": inference_latency_ns,
        "inference_latency_us": inference_latency_ns / 1e3,  # us
        "inference_energy_uj": inference_energy_uj,
        "inferences_per_second": 1e9 / inference_latency_ns,
        "inferences_per_watt": (1e9 / inference_latency_ns) / total_power_w,
    }


def format_metrics_for_display(metrics: Dict[str, float]) -> Dict[str, str]:
    """
    Format metrics for display in the Gradio UI.

    Args:
        metrics: Dictionary of metrics

    Returns:
        Dictionary of formatted metrics
    """
    formatted = {}

    # Format area
    formatted["Total Chip Area"] = f"{metrics['total_area_mm2']:.2f} mm²"

    # Format performance
    formatted["Clock Speed"] = f"{metrics['clock_freq_ghz']:.2f} GHz"
    formatted["Total Performance"] = f"{metrics['total_tflops']:.2f} TFLOPS"
    formatted["Performance per Core"] = f"{metrics['array_tflops']:.2f} TFLOPS"
    formatted["Performance per Watt"] = f"{metrics['tflops_per_watt']:.2f} TFLOPS/W"
    formatted["Performance per Area"] = f"{metrics['tflops_per_mm2']:.2f} TFLOPS/mm²"

    # Format power
    formatted["Total Power"] = f"{metrics['total_power_w']:.2f} W"
    formatted["Power Density"] = f"{metrics['power_density_w_mm2']:.2f} W/mm²"

    # Format latency and throughput
    formatted["Matrix Mult Latency"] = f"{metrics['total_latency_ns']:.2f} ns"
    formatted["Inference Latency"] = f"{metrics['inference_latency_us']:.2f} µs"
    formatted["Throughput"] = f"{metrics['throughput_gops']:.2f} GOPS"

    # Format energy
    formatted["Energy per Matrix Mult"] = f"{metrics['energy_per_matmul_nj']:.2f} nJ"
    # "Inference Energy" is read by calculate_comparison_metrics below,
    # so it must be emitted here
    formatted["Inference Energy"] = f"{metrics['inference_energy_uj']:.2f} µJ"
    # formatted["Inferences per Second"] = f"{metrics['inferences_per_second']:.0f}"
    # formatted["Inferences per Watt"] = f"{metrics['inferences_per_watt']:.0f}"

    return formatted


def calculate_hardware_stats(
    df: pd.DataFrame,
    activation_type: Literal["float8", "bfloat16", "float32"],
    weight_type: Literal["float8", "bfloat16", "float32"],
    systolic_array_size: int,
    num_accelerator_cores: int,
    fast_internals: Literal["Fast", "Efficient"],
    pipeline_level: Literal["None", "Low", "Full"],
    process_node_size: Optional[Literal["7nm", "45nm", "130nm"]] = None,
) -> Tuple[Dict[str, str], Dict[str, str]]:
    """
    Calculate hardware statistics for both lmul and standard IEEE multiplier configurations.

    Args:
        df: DataFrame containing component data
        activation_type: Type of activations
        weight_type: Type of weights
        systolic_array_size: Size of the systolic array
        num_accelerator_cores: Number of accelerator cores
        fast_internals: Whether to use fast or efficient components
        pipeline_level: Level of pipelining
        process_node_size: Process node size (ignored for now)

    Returns:
        Tuple of (lmul_metrics, ieee_metrics) dictionaries
    """
    # Map UI options to dataframe values
    act_dtype = DTYPE_MAP[activation_type]
    weight_dtype = DTYPE_MAP[weight_type]
    is_fast = SPEED_MAP[fast_internals]
    architecture = PIPELINE_MAP[pipeline_level]

    # For mixed precision, use the larger precision for the PE.
    # Compare by bit width; lexicographic comparison of the dtype strings
    # ("fp8" vs "fp32") would pick the wrong type.
    dtype_bits = {"fp8": 8, "bf16": 16, "fp32": 32}
    pe_dtype = (
        act_dtype
        if dtype_bits[act_dtype] >= dtype_bits[weight_dtype]
        else weight_dtype
    )

    # Calculate metrics for lmul configuration
    try:
        lmul_mult, lmul_adder = get_pe_components(
            df, "lmul", pe_dtype, is_fast, architecture
        )
        lmul_pe_metrics = calculate_pe_metrics(lmul_mult, lmul_adder)
        lmul_array_metrics = calculate_array_metrics(
            lmul_pe_metrics, systolic_array_size, num_accelerator_cores
        )
        lmul_formatted = format_metrics_for_display(lmul_array_metrics)
    except ValueError as e:
        # If lmul components not found, return error message
        lmul_formatted = {"Error": f"Could not find lmul components: {str(e)}"}

    # Calculate metrics for standard IEEE multiplier configuration
    try:
        ieee_mult, ieee_adder = get_pe_components(
            df, "multiplier", pe_dtype, is_fast, architecture
        )
        ieee_pe_metrics = calculate_pe_metrics(ieee_mult, ieee_adder)
        ieee_array_metrics = calculate_array_metrics(
            ieee_pe_metrics, systolic_array_size, num_accelerator_cores
        )
        ieee_formatted = format_metrics_for_display(ieee_array_metrics)
    except ValueError as e:
        # If IEEE components not found, return error message
        ieee_formatted = {"Error": f"Could not find IEEE components: {str(e)}"}

    return lmul_formatted, ieee_formatted


def calculate_comparison_metrics(
    lmul_metrics: Dict[str, str], ieee_metrics: Dict[str, str]
) -> Dict[str, str]:
    """
    Calculate comparison metrics between lmul and IEEE configurations.

    Args:
        lmul_metrics: Dictionary of lmul metrics
        ieee_metrics: Dictionary of IEEE metrics

    Returns:
        Dictionary of comparison metrics
    """
    comparison = {}

    # Check if there was an error in either calculation
    if "Error" in lmul_metrics or "Error" in ieee_metrics:
        return {"Error": "Cannot calculate comparison due to missing components"}

    # Extract numeric values from formatted strings
    def extract_number(s):
        return float(s.split()[0])

    # Calculate percentage improvements
    try:
        # Area improvement (lower is better)
        lmul_area = extract_number(lmul_metrics["Total Chip Area"])
        ieee_area = extract_number(ieee_metrics["Total Chip Area"])
        area_improvement = (1 - lmul_area / ieee_area) * 100
        comparison["Area Reduction"] = f"{area_improvement:.1f}%"

        # Performance improvement (higher is better)
        lmul_perf = extract_number(lmul_metrics["Total Performance"])
        ieee_perf = extract_number(ieee_metrics["Total Performance"])
        perf_improvement = (lmul_perf / ieee_perf - 1) * 100
        comparison["Performance Improvement"] = f"{perf_improvement:.1f}%"

        # Power efficiency improvement (higher is better)
        lmul_eff = extract_number(lmul_metrics["Performance per Watt"])
        ieee_eff = extract_number(ieee_metrics["Performance per Watt"])
        eff_improvement = (lmul_eff / ieee_eff - 1) * 100
        comparison["Efficiency Improvement"] = f"{eff_improvement:.1f}%"

        # Latency improvement (lower is better)
        lmul_latency = extract_number(lmul_metrics["Inference Latency"])
        ieee_latency = extract_number(ieee_metrics["Inference Latency"])
        latency_improvement = (1 - lmul_latency / ieee_latency) * 100
        comparison["Latency Reduction"] = f"{latency_improvement:.1f}%"

        # Energy improvement (lower is better)
        lmul_energy = extract_number(lmul_metrics["Inference Energy"])
        ieee_energy = extract_number(ieee_metrics["Inference Energy"])
        energy_improvement = (1 - lmul_energy / ieee_energy) * 100
        comparison["Energy Reduction"] = f"{energy_improvement:.1f}%"

    except (ValueError, KeyError) as e:
        comparison["Error"] = f"Error calculating comparisons: {str(e)}"

    return comparison


# Example usage in the Gradio app:
def update_hardware_stats(
    df: pd.DataFrame,
    activation_type: str,
    weight_type: str,
    systolic_array_size: int,
    num_accelerator_cores: int,
    fast_internals: str,
    pipeline_level: str,
    process_node_size: str,
) -> Tuple[Dict[str, str], Dict[str, str], Dict[str, str]]:
    """
    Update hardware statistics for the Gradio app.

    Args:
        df: DataFrame containing component data
        activation_type: Type of activations
        weight_type: Type of weights
        systolic_array_size: Size of the systolic array
        num_accelerator_cores: Number of accelerator cores
        fast_internals: Whether to use fast or efficient components
        pipeline_level: Level of pipelining
        process_node_size: Process node size

    Returns:
        Tuple of (lmul_metrics, ieee_metrics, comparison_metrics) dictionaries
    """
    lmul_metrics, ieee_metrics = calculate_hardware_stats(
        df,
        activation_type,
        weight_type,
        systolic_array_size,
        num_accelerator_cores,
        fast_internals,
        pipeline_level,
        process_node_size,
    )

    comparison_metrics = calculate_comparison_metrics(lmul_metrics, ieee_metrics)

    return lmul_metrics, ieee_metrics, comparison_metrics
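For reference, a minimal sketch of how the helpers in this file chain together. Every number in the toy DataFrame is invented; the columns are simply the ones filter_components and calculate_pe_metrics read (area in µm², power in W, max_arrival_time in ns):

import pandas as pd

from hardware_accelerators.analysis.hardware_stats import (
    calculate_comparison_metrics,
    calculate_hardware_stats,
)

# Hypothetical per-component measurements for one dtype/architecture corner
df = pd.DataFrame(
    [
        {"operation": "lmul", "dtype": "bf16", "is_fast": True,
         "architecture": "pipelined", "area": 850.0, "power": 0.004,
         "max_arrival_time": 0.9},
        {"operation": "multiplier", "dtype": "bf16", "is_fast": True,
         "architecture": "pipelined", "area": 1400.0, "power": 0.007,
         "max_arrival_time": 1.2},
        {"operation": "adder", "dtype": "bf16", "is_fast": True,
         "architecture": "pipelined", "area": 900.0, "power": 0.005,
         "max_arrival_time": 1.1},
    ]
)

lmul_stats, ieee_stats = calculate_hardware_stats(
    df,
    activation_type="bfloat16",
    weight_type="bfloat16",
    systolic_array_size=8,
    num_accelerator_cores=4,
    fast_internals="Fast",
    pipeline_level="Full",
)
print(calculate_comparison_metrics(lmul_stats, ieee_stats))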
hardware_accelerators/analysis/mnist_eval.py ADDED
@@ -0,0 +1,274 @@
# Evaluation function
from itertools import product
import os
import pandas as pd
from pyrtl import WireVector
import torch
from typing import Callable, Type
import numpy as np
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from torch.nn import CrossEntropyLoss
import multiprocessing as mp
from tqdm.auto import tqdm
import time
import csv
from pathlib import Path
import traceback

from ..simulation.matrix_utils import count_batch_gemm_tiles

from .config import (
    NN_TEST_MUL_FNS,
    NN_TEST_SYSTOLIC_ARRAY_SIZE,
    NN_TEST_ACCUM_ADDR_WIDTH,
    NN_TEST_WA_DTYPES,
    NN_TEST_BATCH_SIZE,
)
from hardware_accelerators.dtypes import *
from hardware_accelerators.rtllib.accelerator import CompiledAcceleratorConfig
from hardware_accelerators.rtllib.multipliers import *
from hardware_accelerators.nn import load_model
from ..simulation.accelerator import CompiledAcceleratorSimulator


def generate_test_configs(
    weight_act_dtypes: list[tuple[Type[BaseFloat], Type[BaseFloat]]],
    multiplier_fns: list[
        Callable[[WireVector, WireVector, type[BaseFloat]], WireVector]
    ],
):
    configs = []
    for dtypes, mult_fn in product(weight_act_dtypes, multiplier_fns):
        weight_type, act_type = dtypes
        config = CompiledAcceleratorConfig(
            array_size=NN_TEST_SYSTOLIC_ARRAY_SIZE,
            weight_type=weight_type,
            activation_type=act_type,
            multiplier=mult_fn,
            accum_addr_width=NN_TEST_ACCUM_ADDR_WIDTH,
        )
        configs.append(config)
    return configs


def evaluate_with_progress(
    config,
    dataset,
    batch_size,
    criterion=CrossEntropyLoss(),
    process_id=0,
):
    """Evaluate a model with progress tracking for the entire dataset"""
    # Define a complete result template with default values
    result = {
        "config": config.name,
        "weight_type": config.weight_type.__name__,
        "activation_type": config.activation_type.__name__,
        "multiplier": config.multiplier.__name__,
        "avg_loss": float("nan"),
        "accuracy": float("nan"),
        "total_time": 0,
        "batch_size": batch_size,
        "total_batches": 0,
        "total_samples": 0,
        "samples_per_second": 0,
        "error": None,  # Will be None for successful runs
    }

    try:
        start_time = time.time()

        # Load the appropriate model based on weight type
        if config.weight_type == Float32:
            model = load_model("./models/mlp_mnist_fp32.pth")
        else:
            model = load_model("./models/mlp_mnist_bf16.pth")

        # Create simulator
        sim = CompiledAcceleratorSimulator(config, model=model)

        if not sim.model_loaded:
            result["error"] = "No model loaded"
            return result

        correct = 0
        total = 0
        running_loss = 0.0

        data_loader = DataLoader(dataset, batch_size=batch_size, shuffle=False)
        total_batches = len(data_loader)
        tiles_per_batch = count_batch_gemm_tiles(
            sim.hidden_dim, sim.input_dim + 1, sim.config.array_size
        ) + count_batch_gemm_tiles(
            sim.output_dim, sim.hidden_dim + 1, sim.config.array_size
        )
        result["total_batches"] = total_batches

        # Create a progress bar for this specific simulation
        desc = f"Config {config.name} ({config.weight_type.__name__}/{config.activation_type.__name__})"
        with tqdm(
            total=total_batches * tiles_per_batch,
            desc=desc,
            position=process_id + 1,
            leave=False,
        ) as pbar:
            for batch, labels in data_loader:
                batch_size_actual = batch.shape[0]
                batch = batch.reshape(batch_size_actual, -1).numpy()

                # Time the prediction
                outputs = sim.predict_batch(batch, pbar)

                loss = criterion(torch.tensor(outputs), labels)
                running_loss += loss.item()

                # Get predictions from the maximum value
                predicted = np.argmax(outputs, axis=1)
                total += labels.size(0)
                correct += (predicted == labels).sum().item()

        end_time = time.time()
        total_time = end_time - start_time

        # Update result with actual values
        result.update(
            {
                "avg_loss": running_loss / total_batches,
                "accuracy": 100.0 * correct / total,
                "total_time": total_time,
                "total_samples": total,
                "samples_per_second": total / total_time,
            }
        )

        return result

    except Exception as e:
        error_msg = (
            f"Error evaluating {config.name}: {str(e)}\n{traceback.format_exc()}"
        )
        print(error_msg)
        result["error"] = str(e)
        return result


def save_result(result, output_file):
    """Save a single result to CSV file"""
    file_exists = os.path.isfile(output_file)

    with open(output_file, "a", newline="") as f:
        writer = csv.DictWriter(f, fieldnames=list(result.keys()))
        if not file_exists:
            writer.writeheader()
        writer.writerow(result)


def process_config(config, dataset, batch_size, output_file, process_id):
    """Process a single configuration and save results"""
    # Set process name for better monitoring
    mp.current_process().name = f"Sim-{config.name}"

    # print(f"Starting evaluation of {config.name}")
    result = evaluate_with_progress(config, dataset, batch_size, process_id=process_id)

    # Save result immediately
    save_result(result, output_file)

    print(
        f"Completed evaluation of {config.name}: Accuracy = {result.get('accuracy', float('nan')):.2f}%, "
        f"Time = {result.get('total_time', 0):.2f}s, "
        f"Speed = {result.get('samples_per_second', 0):.2f} samples/s"
    )
    return result


def main():
    # Create output directory
    output_dir = Path("results")
    output_dir.mkdir(exist_ok=True)
    output_file = output_dir / "mnist_eval.csv"

    # Data transformation: convert images to tensor and normalize them
    transform = transforms.Compose(
        [
            transforms.ToTensor(),
            transforms.Normalize((0.1307,), (0.3081,)),
        ]
    )

    # Download MNIST test data
    test_dataset = datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )

    configs = generate_test_configs(NN_TEST_WA_DTYPES, NN_TEST_MUL_FNS)

    # Create a multiprocessing pool (always at least one process)
    num_processes = max(1, min(len(configs), mp.cpu_count() - 2))
    print(f"Using {num_processes} processes to evaluate {len(configs)} configurations")
    print(
        f"Each simulation will process the entire test dataset with batch size {NN_TEST_BATCH_SIZE}"
    )

    # Clear the screen and set up for progress bars
    print("\n\n")

    # Start the pool and process configurations
    with mp.Pool(processes=num_processes) as pool:
        # Create a list to track all tasks
        tasks = []

        # Submit all tasks
        for i, config in enumerate(configs):
            task = pool.apply_async(
                process_config,
                args=(config, test_dataset, NN_TEST_BATCH_SIZE, output_file, i),
            )
            tasks.append((config.name, task))

        # Set up the main progress bar at the top
        with tqdm(total=len(tasks), desc="Overall Progress", position=0) as pbar:
            completed = 0
            while completed < len(tasks):
                new_completed = sum(1 for _, task in tasks if task.ready())
                if new_completed > completed:
                    pbar.update(new_completed - completed)
                    completed = new_completed
                time.sleep(0.5)

        # Make sure all tasks are properly completed and collect results
        all_results = []
        for config_name, task in tasks:
            try:
                result = task.get()
                all_results.append(result)
            except Exception as e:
                print(f"Error in task {config_name}: {str(e)}")

    print(f"All evaluations complete. Results saved to {output_file}")

    # Create a summary DataFrame and display it
    if all_results:
        df = pd.DataFrame(all_results)
        print("\nSummary of Results (sorted by accuracy):")
        summary_cols = [
            "config",
            "weight_type",
            "activation_type",
            "multiplier",
            "accuracy",
            "total_time",
            "samples_per_second",
        ]
        print(df[summary_cols].sort_values("accuracy", ascending=False))

        print("\nSummary of Results (sorted by speed):")
        print(df[summary_cols].sort_values("samples_per_second", ascending=False))


if __name__ == "__main__":
    # Set start method for multiprocessing
    mp.set_start_method("spawn", force=True)  # Use 'spawn' for better compatibility
    main()
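generate_test_configs above is a plain cartesian product over (weight, activation) dtype pairs and multiplier functions. A standalone sketch of the same sweep shape, with hypothetical placeholders standing in for NN_TEST_WA_DTYPES and NN_TEST_MUL_FNS:

from itertools import product

wa_dtypes = [("BF16", "BF16"), ("Float8", "BF16")]  # placeholder dtype pairs
mul_fns = ["float_multiplier", "lmul_fast"]  # placeholder multiplier names

configs = [
    {"weight": w, "activation": a, "multiplier": m}
    for (w, a), m in product(wa_dtypes, mul_fns)
]
for c in configs:
    print(c)  # 2 dtype pairs x 2 multipliers -> 4 configurations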
hardware_accelerators/analysis/simple_circuits.py ADDED
@@ -0,0 +1,258 @@
import pyrtl
from pyrtl import WireVector, Input, Output, Simulation
import numpy as np
import sys
import os

# Add the parent directory to the path so we can import from hardware_accelerators
sys.path.append(
    os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__))))
)

from hardware_accelerators.dtypes import Float16, Float32, Float8, BF16


def create_simple_adder(data_type):
    """
    Create a simple adder circuit (a + b) using PyRTL's built-in operators.

    Args:
        data_type: The data type to use (only used for bitwidth)

    Returns:
        The PyRTL working block
    """
    # Clear any existing PyRTL design
    pyrtl.reset_working_block()

    # Create input and output wires
    a = Input(data_type.bitwidth(), 'a')
    b = Input(data_type.bitwidth(), 'b')
    result = Output(data_type.bitwidth(), 'result')

    # Create adder using PyRTL's built-in addition
    # Note: This treats the inputs as unsigned integers, not floating point
    result <<= a + b

    return pyrtl.working_block()


def create_simple_multiplier(data_type):
    """
    Create a simple multiplier circuit (a * b) using PyRTL's built-in operators.

    Args:
        data_type: The data type to use (only used for bitwidth)

    Returns:
        The PyRTL working block
    """
    # Clear any existing PyRTL design
    pyrtl.reset_working_block()

    # Create input and output wires
    a = Input(data_type.bitwidth(), 'a')
    b = Input(data_type.bitwidth(), 'b')
    result = Output(data_type.bitwidth(), 'result')

    # Create multiplier using PyRTL's built-in multiplication
    # Note: This treats the inputs as unsigned integers, not floating point
    # We'll truncate the result to match the input bitwidth
    mult_result = a * b
    result <<= mult_result[:data_type.bitwidth()]

    return pyrtl.working_block()


def create_pipelined_adder(data_type):
    """
    Create a pipelined adder circuit (a + b) using PyRTL's built-in operators.

    Args:
        data_type: The data type to use (only used for bitwidth)

    Returns:
        The PyRTL working block and the result wire
    """
    # Clear any existing PyRTL design
    pyrtl.reset_working_block()

    # Create input and output wires
    a = Input(data_type.bitwidth(), 'a')
    b = Input(data_type.bitwidth(), 'b')
    result = Output(data_type.bitwidth(), 'result')

    # Create pipeline registers
    a_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='a_reg')
    b_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='b_reg')

    # Connect input to registers
    a_reg.next <<= a
    b_reg.next <<= b

    # Perform addition in the next stage using PyRTL's built-in addition
    add_result = a_reg + b_reg

    # Connect to output
    result <<= add_result

    return pyrtl.working_block(), result


def create_pipelined_multiplier(data_type):
    """
    Create a pipelined multiplier circuit (a * b) using PyRTL's built-in operators.

    Args:
        data_type: The data type to use (only used for bitwidth)

    Returns:
        The PyRTL working block and the result wire
    """
    # Clear any existing PyRTL design
    pyrtl.reset_working_block()

    # Create input and output wires
    a = Input(data_type.bitwidth(), 'a')
    b = Input(data_type.bitwidth(), 'b')
    result = Output(data_type.bitwidth(), 'result')

    # Create pipeline registers
    a_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='a_reg')
    b_reg = pyrtl.Register(bitwidth=data_type.bitwidth(), name='b_reg')

    # Connect input to registers
    a_reg.next <<= a
    b_reg.next <<= b

    # Perform multiplication in the next stage using PyRTL's built-in multiplication
    mult_result = a_reg * b_reg

    # Truncate the result to match the input bitwidth
    truncated_result = mult_result[:data_type.bitwidth()]

    # Connect to output
    result <<= truncated_result

    return pyrtl.working_block(), result


def simulate_circuit(block, data_type, num_cycles=10):
    """
    Simulate a circuit with random inputs.

    Args:
        block: The PyRTL block to simulate
        data_type: The data type used (only for bitwidth)
        num_cycles: Number of simulation cycles

    Returns:
        sim: The simulation object
        tracer: The simulation tracer
    """
    # Create simulation with tracing enabled
    sim = Simulation()

    # Create test data (random integers within the valid range for the bitwidth)
    bitwidth = data_type.bitwidth()

    # Handle large bitwidths safely
    if bitwidth > 30:  # Avoid overflow for large bitwidths
        a_values = [np.random.randint(0, 2**16) for _ in range(num_cycles)]
        b_values = [np.random.randint(0, 2**16) for _ in range(num_cycles)]
    else:
        max_val = 2**bitwidth - 1
        a_values = [np.random.randint(0, max_val) for _ in range(num_cycles)]
        b_values = [np.random.randint(0, max_val) for _ in range(num_cycles)]

    # Create input dictionaries for each cycle
    input_vectors = []
    for i in range(num_cycles):
        cycle_inputs = {
            'a': a_values[i],
            'b': b_values[i]
        }
        input_vectors.append(cycle_inputs)

    # Run simulation for each cycle
    for i in range(num_cycles):
        sim.step(input_vectors[i])

    # Get trace from the tracer
    tracer = sim.tracer

    return sim, tracer


def main():
    """Main function to create and test the simplified circuits."""
    # Data types to test (only used for bitwidth)
    data_types = [Float8, BF16, Float16, Float32]

    # Results dictionary
    results = {
        "simple_adder": {},
        "simple_multiplier": {},
        "pipelined_adder": {},
        "pipelined_multiplier": {}
    }

    # Test simple adders
    print("=== Testing Simple Adders ===")
    for dtype in data_types:
        print(f"\nCreating and simulating {dtype.__name__} adder (bitwidth: {dtype.bitwidth()})...")
        block = create_simple_adder(dtype)
        sim, tracer = simulate_circuit(block, dtype)

        # Print some results
        if 'result' in tracer.trace:
            output_values = tracer.trace['result']
            print(f"  result: {output_values}")

        results["simple_adder"][dtype.__name__] = block

    # Test simple multipliers
    print("\n=== Testing Simple Multipliers ===")
    for dtype in data_types:
        print(f"\nCreating and simulating {dtype.__name__} multiplier (bitwidth: {dtype.bitwidth()})...")
        block = create_simple_multiplier(dtype)
        sim, tracer = simulate_circuit(block, dtype)

        # Print some results
        if 'result' in tracer.trace:
            output_values = tracer.trace['result']
            print(f"  result: {output_values}")

        results["simple_multiplier"][dtype.__name__] = block

    # Test pipelined adders
    print("\n=== Testing Pipelined Adders ===")
    for dtype in data_types:
        print(f"\nCreating and simulating {dtype.__name__} pipelined adder (bitwidth: {dtype.bitwidth()})...")
        block, _ = create_pipelined_adder(dtype)
        sim, tracer = simulate_circuit(block, dtype)

        # Print some results
        if 'result' in tracer.trace:
            output_values = tracer.trace['result']
            print(f"  result: {output_values}")

        results["pipelined_adder"][dtype.__name__] = block

    # Test pipelined multipliers
    print("\n=== Testing Pipelined Multipliers ===")
    for dtype in data_types:
        print(f"\nCreating and simulating {dtype.__name__} pipelined multiplier (bitwidth: {dtype.bitwidth()})...")
        block, _ = create_pipelined_multiplier(dtype)
        sim, tracer = simulate_circuit(block, dtype)

        # Print some results
        if 'result' in tracer.trace:
            output_values = tracer.trace['result']
            print(f"  result: {output_values}")

        results["pipelined_multiplier"][dtype.__name__] = block

    return results


if __name__ == "__main__":
    main()
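A minimal usage sketch for the circuits above, assuming it is run from the repository root so the hardware_accelerators package imports resolve:

from hardware_accelerators.analysis.simple_circuits import (
    create_simple_adder,
    simulate_circuit,
)
from hardware_accelerators.dtypes import BF16

block = create_simple_adder(BF16)  # 16-bit unsigned integer adder
sim, tracer = simulate_circuit(block, BF16, num_cycles=4)
print(tracer.trace["result"])  # four truncated 16-bit sums, one per cycle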
hardware_accelerators/analysis/verilog_export.py ADDED
@@ -0,0 +1,86 @@
| 1 |
+
import os
|
| 2 |
+
import sys
|
| 3 |
+
import pyrtl
|
| 4 |
+
from pyrtl import WireVector, Input, Output, Simulation
|
| 5 |
+
|
| 6 |
+
# Add the parent directory to the path so we can import from hardware_accelerators
|
| 7 |
+
sys.path.append(os.path.dirname(os.path.dirname(os.path.dirname(os.path.abspath(__file__)))))
|
| 8 |
+
|
| 9 |
+
from hardware_accelerators.dtypes import Float16, Float32, Float8, BF16
|
| 10 |
+
from hardware_accelerators.analysis.simple_circuits import (
|
| 11 |
+
create_simple_adder,
|
| 12 |
+
create_simple_multiplier,
|
| 13 |
+
create_pipelined_adder,
|
| 14 |
+
create_pipelined_multiplier
|
| 15 |
+
)
|
| 16 |
+
|
| 17 |
+
def export_to_verilog(block, output_filename, add_reset=True, initialize_registers=False):
|
| 18 |
+
"""
|
| 19 |
+
Export a PyRTL block to a Verilog file.
|
| 20 |
+
|
| 21 |
+
Args:
|
| 22 |
+
block: The PyRTL block to export
|
| 23 |
+
output_filename: The filename to write the Verilog code to
|
| 24 |
+
add_reset: If reset logic should be added. Options are:
|
| 25 |
+
False (no reset logic),
|
| 26 |
+
True (synchronous reset logic),
|
| 27 |
+
'asynchronous' (asynchronous reset logic)
|
| 28 |
+
initialize_registers: Initialize Verilog registers to their reset_value
|
| 29 |
+
"""
|
| 30 |
+
# Create the output directory if it doesn't exist
|
| 31 |
+
os.makedirs(os.path.dirname(output_filename), exist_ok=True)
|
| 32 |
+
|
| 33 |
+
# Export the block to Verilog
|
| 34 |
+
with open(output_filename, 'w') as f:
|
| 35 |
+
pyrtl.output_to_verilog(
|
| 36 |
+
f,
|
| 37 |
+
add_reset=add_reset,
|
| 38 |
+
initialize_registers=initialize_registers,
|
| 39 |
+
block=block
|
| 40 |
+
)
|
| 41 |
+
|
| 42 |
+
print(f"Exported Verilog to {output_filename}")
|
| 43 |
+
|
| 44 |
+
def export_all_circuits():
|
| 45 |
+
"""
|
| 46 |
+
Export all simple circuits to Verilog files.
|
| 47 |
+
"""
|
| 48 |
+
# Create output directory
|
| 49 |
+
output_dir = os.path.join(os.path.dirname(os.path.abspath(__file__)), "verilog_output")
|
| 50 |
+
os.makedirs(output_dir, exist_ok=True)
|
| 51 |
+
|
| 52 |
+
# List of data types to use
|
| 53 |
+
data_types = [Float8, Float16, BF16, Float32]
|
| 54 |
+
|
| 55 |
+
# Export simple adder for each data type
|
| 56 |
+
for dtype in data_types:
|
| 57 |
+
block = create_simple_adder(dtype)
|
| 58 |
+
output_filename = os.path.join(output_dir, f"simple_adder_{dtype.__name__}.v")
|
| 59 |
+
export_to_verilog(block, output_filename)
|
| 60 |
+
|
| 61 |
+
# Export simple multiplier for each data type
|
| 62 |
+
for dtype in data_types:
|
| 63 |
+
block = create_simple_multiplier(dtype)
|
| 64 |
+
output_filename = os.path.join(output_dir, f"simple_multiplier_{dtype.__name__}.v")
|
| 65 |
+
export_to_verilog(block, output_filename)
|
| 66 |
+
|
| 67 |
+
# Export pipelined adder for each data type
|
| 68 |
+
for dtype in data_types:
|
| 69 |
+
block, _ = create_pipelined_adder(dtype)
|
| 70 |
+
output_filename = os.path.join(output_dir, f"pipelined_adder_{dtype.__name__}.v")
|
| 71 |
+
export_to_verilog(block, output_filename, initialize_registers=True)
|
| 72 |
+
|
| 73 |
+
# Export pipelined multiplier for each data type
|
| 74 |
+
for dtype in data_types:
|
| 75 |
+
block, _ = create_pipelined_multiplier(dtype)
|
| 76 |
+
output_filename = os.path.join(output_dir, f"pipelined_multiplier_{dtype.__name__}.v")
|
| 77 |
+
export_to_verilog(block, output_filename, initialize_registers=True)
|
| 78 |
+
|
| 79 |
+
def main():
|
| 80 |
+
# Export all circuits
|
| 81 |
+
export_all_circuits()
|
| 82 |
+
|
| 83 |
+
print("All circuits exported to Verilog successfully!")
|
| 84 |
+
|
| 85 |
+
if __name__ == "__main__":
|
| 86 |
+
main()
|
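
For context, a minimal usage sketch of the exporter above. The reset-mode strings are the ones the docstring lists; the output path is illustrative only:

    # Hypothetical one-off export, mirroring export_all_circuits() above.
    from hardware_accelerators.dtypes import BF16
    from hardware_accelerators.analysis.simple_circuits import create_pipelined_adder
    from hardware_accelerators.analysis.verilog_export import export_to_verilog

    block, _ = create_pipelined_adder(BF16)
    # 'asynchronous' swaps the synchronous reset for an async one in the output.
    export_to_verilog(block, "verilog_output/pipelined_adder_BF16.v",
                      add_reset="asynchronous", initialize_registers=True)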
hardware_accelerators/analysis/verilog_output/pipelined_adder_BF16.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    reg[15:0] a_reg = 16'd0;
+    reg[15:0] b_reg = 16'd0;
+
+    wire[16:0] tmp20;
+    wire[15:0] tmp21;
+
+    // Combinational
+    assign result = tmp21;
+    assign tmp20 = a_reg + b_reg;
+    assign tmp21 = {tmp20[15], tmp20[14], tmp20[13], tmp20[12], tmp20[11], tmp20[10], tmp20[9], tmp20[8], tmp20[7], tmp20[6], tmp20[5], tmp20[4], tmp20[3], tmp20[2], tmp20[1], tmp20[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
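
The exported module above registers both operands and computes the (truncated) sum combinationally. As a rough sketch of how that structure is expressed in PyRTL (a standalone illustration; the project's actual builder is create_pipelined_adder in simple_circuits.py, not shown here):

    import pyrtl

    pyrtl.reset_working_block()
    a = pyrtl.Input(16, "a")
    b = pyrtl.Input(16, "b")
    result = pyrtl.Output(16, "result")

    a_reg = pyrtl.Register(16, "a_reg")  # one pipeline stage per operand
    b_reg = pyrtl.Register(16, "b_reg")
    a_reg.next <<= a
    b_reg.next <<= b
    result <<= (a_reg + b_reg)[:16]  # keep the low 16 bits, as tmp20 -> tmp21 does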
hardware_accelerators/analysis/verilog_output/pipelined_adder_Float16.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    reg[15:0] a_reg = 16'd0;
+    reg[15:0] b_reg = 16'd0;
+
+    wire[16:0] tmp18;
+    wire[15:0] tmp19;
+
+    // Combinational
+    assign result = tmp19;
+    assign tmp18 = a_reg + b_reg;
+    assign tmp19 = {tmp18[15], tmp18[14], tmp18[13], tmp18[12], tmp18[11], tmp18[10], tmp18[9], tmp18[8], tmp18[7], tmp18[6], tmp18[5], tmp18[4], tmp18[3], tmp18[2], tmp18[1], tmp18[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/pipelined_adder_Float32.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[31:0] a;
+    input[31:0] b;
+    output[31:0] result;
+
+    reg[31:0] a_reg = 32'd0;
+    reg[31:0] b_reg = 32'd0;
+
+    wire[32:0] tmp22;
+    wire[31:0] tmp23;
+
+    // Combinational
+    assign result = tmp23;
+    assign tmp22 = a_reg + b_reg;
+    assign tmp23 = {tmp22[31], tmp22[30], tmp22[29], tmp22[28], tmp22[27], tmp22[26], tmp22[25], tmp22[24], tmp22[23], tmp22[22], tmp22[21], tmp22[20], tmp22[19], tmp22[18], tmp22[17], tmp22[16], tmp22[15], tmp22[14], tmp22[13], tmp22[12], tmp22[11], tmp22[10], tmp22[9], tmp22[8], tmp22[7], tmp22[6], tmp22[5], tmp22[4], tmp22[3], tmp22[2], tmp22[1], tmp22[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/pipelined_adder_Float8.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[7:0] a;
+    input[7:0] b;
+    output[7:0] result;
+
+    reg[7:0] a_reg = 8'd0;
+    reg[7:0] b_reg = 8'd0;
+
+    wire[8:0] tmp16;
+    wire[7:0] tmp17;
+
+    // Combinational
+    assign result = tmp17;
+    assign tmp16 = a_reg + b_reg;
+    assign tmp17 = {tmp16[7], tmp16[6], tmp16[5], tmp16[4], tmp16[3], tmp16[2], tmp16[1], tmp16[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_BF16.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    reg[15:0] a_reg = 16'd0;
+    reg[15:0] b_reg = 16'd0;
+
+    wire[31:0] tmp28;
+    wire[15:0] tmp29;
+
+    // Combinational
+    assign result = tmp29;
+    assign tmp28 = a_reg * b_reg;
+    assign tmp29 = {tmp28[15], tmp28[14], tmp28[13], tmp28[12], tmp28[11], tmp28[10], tmp28[9], tmp28[8], tmp28[7], tmp28[6], tmp28[5], tmp28[4], tmp28[3], tmp28[2], tmp28[1], tmp28[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float16.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    reg[15:0] a_reg = 16'd0;
+    reg[15:0] b_reg = 16'd0;
+
+    wire[31:0] tmp26;
+    wire[15:0] tmp27;
+
+    // Combinational
+    assign result = tmp27;
+    assign tmp26 = a_reg * b_reg;
+    assign tmp27 = {tmp26[15], tmp26[14], tmp26[13], tmp26[12], tmp26[11], tmp26[10], tmp26[9], tmp26[8], tmp26[7], tmp26[6], tmp26[5], tmp26[4], tmp26[3], tmp26[2], tmp26[1], tmp26[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float32.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[31:0] a;
+    input[31:0] b;
+    output[31:0] result;
+
+    reg[31:0] a_reg = 32'd0;
+    reg[31:0] b_reg = 32'd0;
+
+    wire[63:0] tmp30;
+    wire[31:0] tmp31;
+
+    // Combinational
+    assign result = tmp31;
+    assign tmp30 = a_reg * b_reg;
+    assign tmp31 = {tmp30[31], tmp30[30], tmp30[29], tmp30[28], tmp30[27], tmp30[26], tmp30[25], tmp30[24], tmp30[23], tmp30[22], tmp30[21], tmp30[20], tmp30[19], tmp30[18], tmp30[17], tmp30[16], tmp30[15], tmp30[14], tmp30[13], tmp30[12], tmp30[11], tmp30[10], tmp30[9], tmp30[8], tmp30[7], tmp30[6], tmp30[5], tmp30[4], tmp30[3], tmp30[2], tmp30[1], tmp30[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/pipelined_multiplier_Float8.v
ADDED
@@ -0,0 +1,37 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[7:0] a;
+    input[7:0] b;
+    output[7:0] result;
+
+    reg[7:0] a_reg = 8'd0;
+    reg[7:0] b_reg = 8'd0;
+
+    wire[15:0] tmp24;
+    wire[7:0] tmp25;
+
+    // Combinational
+    assign result = tmp25;
+    assign tmp24 = a_reg * b_reg;
+    assign tmp25 = {tmp24[7], tmp24[6], tmp24[5], tmp24[4], tmp24[3], tmp24[2], tmp24[1], tmp24[0]};
+
+    // Registers
+    always @(posedge clk)
+    begin
+        if (rst) begin
+            a_reg <= 0;
+            b_reg <= 0;
+        end
+        else begin
+            a_reg <= a;
+            b_reg <= b;
+        end
+    end
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_adder_BF16.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    wire[16:0] tmp4;
+    wire[15:0] tmp5;
+
+    // Combinational
+    assign result = tmp5;
+    assign tmp4 = a + b;
+    assign tmp5 = {tmp4[15], tmp4[14], tmp4[13], tmp4[12], tmp4[11], tmp4[10], tmp4[9], tmp4[8], tmp4[7], tmp4[6], tmp4[5], tmp4[4], tmp4[3], tmp4[2], tmp4[1], tmp4[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_adder_Float16.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    wire[16:0] tmp2;
+    wire[15:0] tmp3;
+
+    // Combinational
+    assign result = tmp3;
+    assign tmp2 = a + b;
+    assign tmp3 = {tmp2[15], tmp2[14], tmp2[13], tmp2[12], tmp2[11], tmp2[10], tmp2[9], tmp2[8], tmp2[7], tmp2[6], tmp2[5], tmp2[4], tmp2[3], tmp2[2], tmp2[1], tmp2[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_adder_Float32.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[31:0] a;
+    input[31:0] b;
+    output[31:0] result;
+
+    wire[32:0] tmp6;
+    wire[31:0] tmp7;
+
+    // Combinational
+    assign result = tmp7;
+    assign tmp6 = a + b;
+    assign tmp7 = {tmp6[31], tmp6[30], tmp6[29], tmp6[28], tmp6[27], tmp6[26], tmp6[25], tmp6[24], tmp6[23], tmp6[22], tmp6[21], tmp6[20], tmp6[19], tmp6[18], tmp6[17], tmp6[16], tmp6[15], tmp6[14], tmp6[13], tmp6[12], tmp6[11], tmp6[10], tmp6[9], tmp6[8], tmp6[7], tmp6[6], tmp6[5], tmp6[4], tmp6[3], tmp6[2], tmp6[1], tmp6[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_adder_Float8.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[7:0] a;
+    input[7:0] b;
+    output[7:0] result;
+
+    wire[8:0] tmp0;
+    wire[7:0] tmp1;
+
+    // Combinational
+    assign result = tmp1;
+    assign tmp0 = a + b;
+    assign tmp1 = {tmp0[7], tmp0[6], tmp0[5], tmp0[4], tmp0[3], tmp0[2], tmp0[1], tmp0[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_BF16.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    wire[31:0] tmp12;
+    wire[15:0] tmp13;
+
+    // Combinational
+    assign result = tmp13;
+    assign tmp12 = a * b;
+    assign tmp13 = {tmp12[15], tmp12[14], tmp12[13], tmp12[12], tmp12[11], tmp12[10], tmp12[9], tmp12[8], tmp12[7], tmp12[6], tmp12[5], tmp12[4], tmp12[3], tmp12[2], tmp12[1], tmp12[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_Float16.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[15:0] a;
+    input[15:0] b;
+    output[15:0] result;
+
+    wire[31:0] tmp10;
+    wire[15:0] tmp11;
+
+    // Combinational
+    assign result = tmp11;
+    assign tmp10 = a * b;
+    assign tmp11 = {tmp10[15], tmp10[14], tmp10[13], tmp10[12], tmp10[11], tmp10[10], tmp10[9], tmp10[8], tmp10[7], tmp10[6], tmp10[5], tmp10[4], tmp10[3], tmp10[2], tmp10[1], tmp10[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_Float32.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[31:0] a;
+    input[31:0] b;
+    output[31:0] result;
+
+    wire[63:0] tmp14;
+    wire[31:0] tmp15;
+
+    // Combinational
+    assign result = tmp15;
+    assign tmp14 = a * b;
+    assign tmp15 = {tmp14[31], tmp14[30], tmp14[29], tmp14[28], tmp14[27], tmp14[26], tmp14[25], tmp14[24], tmp14[23], tmp14[22], tmp14[21], tmp14[20], tmp14[19], tmp14[18], tmp14[17], tmp14[16], tmp14[15], tmp14[14], tmp14[13], tmp14[12], tmp14[11], tmp14[10], tmp14[9], tmp14[8], tmp14[7], tmp14[6], tmp14[5], tmp14[4], tmp14[3], tmp14[2], tmp14[1], tmp14[0]};
+
+endmodule
+
hardware_accelerators/analysis/verilog_output/simple_multiplier_Float8.v
ADDED
@@ -0,0 +1,21 @@
+// Generated automatically via PyRTL
+// As one initial test of synthesis, map to FPGA with:
+// yosys -p "synth_xilinx -top toplevel" thisfile.v
+
+module toplevel(clk, rst, a, b, result);
+    input clk;
+    input rst;
+    input[7:0] a;
+    input[7:0] b;
+    output[7:0] result;
+
+    wire[15:0] tmp8;
+    wire[7:0] tmp9;
+
+    // Combinational
+    assign result = tmp9;
+    assign tmp8 = a * b;
+    assign tmp9 = {tmp8[7], tmp8[6], tmp8[5], tmp8[4], tmp8[3], tmp8[2], tmp8[1], tmp8[0]};
+
+endmodule
+
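
Note that in all of these generated modules the adds and multiplies act on the raw encodings, and the final concatenation simply keeps the low half of the wider intermediate. In plain Python terms (an illustration only, not project code):

    # tmp8 = a * b is a full 16-bit product; tmp9 keeps bits [7:0] of it.
    a_bits, b_bits = 0x40, 0x3C      # two arbitrary 8-bit operand encodings
    product = a_bits * b_bits        # 16-bit integer product
    print(hex(product & 0xFF))       # what the 8-bit 'result' port would carry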
hardware_accelerators/app.py
ADDED
@@ -0,0 +1,388 @@
+import os
+from typing import Literal
+import gradio as gr
+from gradio.components.image_editor import EditorValue
+import numpy as np
+import pandas as pd
+from PIL import Image
+from torchvision import transforms
+from .nn.util import load_model
+from .rtllib.lmul import lmul_simple
+from .rtllib.multipliers import float_multiplier
+from .dtypes import Float8, BF16
+from .rtllib import (
+    CompiledAcceleratorConfig,
+)
+from .simulation import CompiledAcceleratorSimulator
+from .analysis.hardware_stats import (
+    calculate_hardware_stats,
+)
+
+__all__ = ["create_app"]
+
+# ------------ CONSTANTS ------------ #
+
+# Load the component data
+data_path = os.environ.get("COMPONENT_DATA_PATH", "results/component_data.csv")
+DF = pd.read_csv(data_path)
+
+# Load the trained model
+MODEL = load_model("models/mlp_mnist.pth", "cpu")  # type: ignore
+MODEL.eval()
+
+classes = [
+    "zero",
+    "one",
+    "two",
+    "three",
+    "four",
+    "five",
+    "six",
+    "seven",
+    "eight",
+    "nine",
+]
+labels_value = {label: 0.0 for label in classes}
+
+accelerator_dtypes = ["float8", "bfloat16", "float32"]
+dtype_map = {
+    "float8": Float8,
+    "bfloat16": BF16,
+    "float32": BF16,  # TODO: use Float32, but not right now because it's slow
+}
+
+
+mult_map = {
+    "IEEE 754": float_multiplier,
+    "l-mul": lmul_simple,
+}
+
+
+# ------------ Event Listener Functions ------------ #
+
+
+def filter_activation_types(weight_type: str, activation_type: str):
+    if weight_type == "float8":
+        return gr.update(choices=accelerator_dtypes)
+    elif weight_type == "bfloat16":
+        if activation_type == "float8":
+            activation_type = "bfloat16"
+        return gr.update(value=activation_type, choices=["bfloat16", "float32"])
+    elif weight_type == "float32":
+        if activation_type != "float32":
+            activation_type = "float32"
+        return gr.update(value=activation_type, choices=["float32"])
+
+
+def warn_w8a8(weight_type: str, activation_type: str):
+    if weight_type == "float8" and activation_type == "float8":
+        gr.Warning(
+            "W8A8 has poor performance without quantization, which is not yet supported in simulation. Theoretical results are still calculated for FP8 hardware",
+            duration=5,
+        )
+
+
+def image_to_tensor(sketchpad: EditorValue):
+    image = sketchpad["composite"]
+    image = image.resize((28, 28), Image.Resampling.LANCZOS)  # type: ignore
+    img_array = np.transpose(np.array(image), (2, 0, 1))[-1]
+
+    # Preprocessing: convert image to tensor and normalize
+    transform = transforms.Compose(
+        [
+            transforms.ToTensor(),
+            transforms.Normalize((0.1307,), (0.3081,)),
+        ]
+    )
+    tensor_image = transform(img_array)
+    return tensor_image
+
+
+def calculate_stats(
+    activation_type: Literal["float8", "bfloat16", "float32"],
+    weight_type: Literal["float8", "bfloat16", "float32"],
+    systolic_array_size: int,
+    num_accelerator_cores: int,
+    fast_internals: Literal["Fast", "Efficient"],
+    pipeline_level: Literal["None", "Low", "Full"],
+    process_node_size: Literal["7nm", "45nm", "130nm"],
+):
+    """
+    Calculate hardware statistics for both lmul and standard IEEE multiplier configurations.
+
+    Args:
+        activation_type: Type of activations
+        weight_type: Type of weights
+        systolic_array_size: Size of the systolic array
+        num_accelerator_cores: Number of accelerator cores
+        fast_internals: Whether to use fast or efficient components
+        pipeline_level: Level of pipelining
+        process_node_size: Process node size (ignored for now)
+
+    Returns:
+        Tuple of (lmul_html, ieee_html) formatted metric blocks (the comparison output is currently commented out)
+    """
+    stat_map = {
+        "float8": "fp8",
+        "bfloat16": "bf16",
+        "float32": "fp32",
+        "Fast": True,
+        "Efficient": False,
+        # NOTE: duplicate "None"/"Low" keys removed; a dict keeps only the
+        # last value per key, so the pipeline-level strings below are the ones used.
+        "None": "combinational",
+        "Low": "combinational",
+        "Full": "pipelined",
+        "7nm": 7,
+        "45nm": 45,
+        "130nm": 130,
+    }
+    # Calculate hardware stats using the functions from hardware_stats.py
+    lmul_metrics, ieee_metrics = calculate_hardware_stats(
+        DF,
+        activation_type,
+        weight_type,
+        systolic_array_size,
+        num_accelerator_cores,
+        fast_internals,
+        pipeline_level,
+        process_node_size,
+    )
+
+    # comparison_metrics = calculate_comparison_metrics(lmul_metrics, ieee_metrics)
+
+    # Format the metrics for display in the Gradio UI
+    lmul_html = "<div style='text-align: left;'>"
+    for key, value in lmul_metrics.items():
+        lmul_html += f"<p><b>{key}:</b> {value}</p>"
+    lmul_html += "</div>"
+
+    ieee_html = "<div style='text-align: left;'>"
+    for key, value in ieee_metrics.items():
+        ieee_html += f"<p><b>{key}:</b> {value}</p>"
+    ieee_html += "</div>"
+
+    # comparison_html = "<div style='text-align: left;'>"
+    # comparison_html += "<h3>Comparison (lmul vs IEEE)</h3>"
+    # for key, value in comparison_metrics.items():
+    #     comparison_html += f"<p><b>{key}:</b> {value}</p>"
+    # comparison_html += "</div>"
+
+    return (
+        lmul_html,
+        ieee_html,
+        # comparison_html,
+    )
+
+
+def predict_lmul(
+    sketchpad: EditorValue,
+    weight: str,
+    activation: str,
+    gr_progress=gr.Progress(track_tqdm=True),
+):
+    if weight == "float8" and activation == "float8":
+        activation = "bfloat16"
+    config = CompiledAcceleratorConfig(
+        array_size=8,
+        activation_type=dtype_map[activation],
+        weight_type=dtype_map[weight],
+        multiplier=lmul_simple,
+    )
+    sim = CompiledAcceleratorSimulator(config, MODEL)
+
+    x = image_to_tensor(sketchpad).detach().numpy().flatten()
+    probabilities = sim.predict(x)
+    return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
+
+
+def predict_ieee(
+    sketchpad: EditorValue,
+    weight: str,
+    activation: str,
+    gr_progress=gr.Progress(track_tqdm=True),
+):
+    if weight == "float8" and activation == "float8":
+        activation = "bfloat16"
+    config = CompiledAcceleratorConfig(
+        array_size=8,
+        activation_type=dtype_map[activation],
+        weight_type=dtype_map[weight],
+        multiplier=float_multiplier,
+    )
+    simulator = CompiledAcceleratorSimulator(config, MODEL)
+
+    x = image_to_tensor(sketchpad).detach().numpy().flatten()
+    probabilities = simulator.predict(x)
+    return {cls: float(prob) for cls, prob in zip(classes, probabilities)}
+
+
+# ------------ Blocks UI Layout ------------ #
+
+
+def create_app():
+    with gr.Blocks(fill_height=False, fill_width=False, title=__file__) as demo:
+
+        gr.Markdown("## Draw a digit to see the model's prediction")
+        with gr.Row(equal_height=False):
+            with gr.Column(scale=3):
+                canvas_size = (400, 400)
+                sketchpad = gr.Sketchpad(
+                    # label="Draw a digit",
+                    type="pil",  # Changed to PIL
+                    transforms=(),
+                    layers=False,
+                    canvas_size=canvas_size,
+                    # scale=2,
+                    container=False,
+                )
+                predict_btn = gr.Button(
+                    "Run Hardware Simulation",
+                    variant="primary",
+                )
+
+                # with gr.Accordion("Accelerator Configuration", open=True):
+                with gr.Group():
+                    with gr.Row():  # Weight and activation types
+                        weight_type_component = gr.Radio(
+                            label="Weights d-type",
+                            choices=accelerator_dtypes,
+                            value="float8",
+                            interactive=True,
+                        )
+                        activation_type_component = gr.Radio(
+                            label="Activations d-type",
+                            choices=accelerator_dtypes,
+                            value="bfloat16",
+                            interactive=True,
+                        )
+                    # Prevent w8a8 from being selected, or any other combination where act < weight
+                    weight_type_component.select(
+                        fn=filter_activation_types,
+                        inputs=[weight_type_component, activation_type_component],
+                        outputs=activation_type_component,
+                    )
+                    gr.on(
+                        triggers=[
+                            weight_type_component.select,
+                            activation_type_component.select,
+                        ],
+                        fn=warn_w8a8,
+                        inputs=[weight_type_component, activation_type_component],
+                    )
+                    with gr.Row():
+                        systolic_array_size_component = gr.Slider(
+                            label="Systolic Array Size",
+                            info="Dimensions of the matrix acceleration unit",
+                            minimum=4,
+                            maximum=512,
+                            step=1,
+                            value=16,
+                            interactive=True,
+                        )
+                        num_accelerator_cores_component = gr.Number(
+                            label="Number of Accelerator Cores",
+                            info="Total number of accelerator units per chip",
+                            minimum=1,
+                            maximum=1024,
+                            step=1,
+                            value=1,
+                            interactive=True,
+                        )
+                    with gr.Row(equal_height=True):
+                        fast_internals_component = gr.Dropdown(
+                            label="Internal Component Type",
+                            info="Configure the lowest level hardware units to use a faster or more efficient design.",
+                            choices=["Fast", "Efficient"],
+                            value="Fast",
+                            interactive=True,
+                            filterable=False,
+                        )
+                        pipeline_level_component = gr.Dropdown(
+                            label="Pipeline Level",
+                            info="Configure the pipeline level of processing elements within the accelerator. Low uses a single register between multipliers and adders. Full uses pipelined individual components.",
+                            choices=["None", "Low", "Full"],
+                            value="Full",
+                            interactive=True,
+                            filterable=False,
+                        )
+                        process_node_size_component = gr.Radio(
+                            label="Process Node Size",
+                            info="Configure the process node size of the hardware units. Smaller nodes are faster and use less area.",
+                            choices=["7nm", "45nm", "130nm"],
+                            value="45nm",
+                            interactive=False,
+                        )
+
+            with gr.Column(scale=2):
+                lmul_predictions = gr.Label(
+                    label="l-mul Simulator Predictions",
+                    value=labels_value,
+                    min_width=100,
+                )
+
+                lmul_html = gr.HTML(
+                    label="L-mul Hardware Stats",
+                    show_label=True,
+                    container=True,
+                )
+
+            with gr.Column(scale=2):
+                ieee_predictions = gr.Label(
+                    label="Hardware Simulator Predictions",
+                    value=labels_value,
+                    min_width=100,
+                )
+                ieee_html = gr.HTML(
+                    label="IEEE Multiplier Hardware Stats",
+                    show_label=True,
+                    container=True,
+                )
+
+        # ------------ Event Listeners ------------ #
+
+        predict_btn.click(
+            fn=predict_ieee,
+            inputs=[sketchpad, weight_type_component, activation_type_component],
+            outputs=ieee_predictions,
+        )
+
+        # TODO: implement simulator_predict
+        predict_btn.click(
+            fn=predict_lmul,
+            inputs=[sketchpad, weight_type_component, activation_type_component],
+            outputs=lmul_predictions,
+        )
+
+        gr.on(
+            triggers=[
+                demo.load,
+                activation_type_component.change,
+                weight_type_component.change,
+                systolic_array_size_component.change,
+                num_accelerator_cores_component.change,
+                fast_internals_component.change,
+                pipeline_level_component.change,
+                process_node_size_component.change,
+            ],
+            fn=calculate_stats,
+            inputs=[
+                activation_type_component,
+                weight_type_component,
+                systolic_array_size_component,
+                num_accelerator_cores_component,
+                fast_internals_component,
+                pipeline_level_component,
+                process_node_size_component,
+            ],
+            outputs=[lmul_html, ieee_html],
+            show_progress="hidden",
+        )
+
+    return demo
+
+
+if __name__ == "__main__":
+    demo = create_app()
+    demo.queue()
+    demo.launch(share=False)
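
A launch sketch for the app module above, following its own __main__ block; the server arguments are assumptions for containerized serving, not taken from this file:

    from hardware_accelerators.app import create_app

    demo = create_app()
    demo.queue()
    demo.launch(server_name="0.0.0.0", server_port=7860)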
hardware_accelerators/compile.py
ADDED
@@ -0,0 +1,167 @@
+import os
+import time
+import multiprocessing
+import pyrtl
+from tqdm import tqdm
+from functools import partial
+
+from .simulation import CompiledAcceleratorSimulator
+
+from .rtllib import float_multiplier, lmul_simple
+from .rtllib.accelerator import CompiledAcceleratorConfig
+from .dtypes import BaseFloat, Float32, Float16, BF16, Float8
+from typing import Iterator, Type, List, Callable
+from itertools import product
+
+
+def generate_accelerator_configs(
+    array_size: int = 8,
+    dtypes: List[Type[BaseFloat]] | None = None,
+    multipliers: List[Callable] | None = None,
+    **kwargs,
+) -> Iterator[CompiledAcceleratorConfig]:
+    """
+    Generate all valid CompiledAcceleratorConfig combinations.
+
+    Args:
+        array_size: Size of the systolic array
+        dtypes: List of data types to consider. Defaults to [Float8, BF16, Float32]
+        multipliers: List of multiplier functions. Defaults to [float_multiplier, lmul_simple]
+
+    Yields:
+        Valid CompiledAcceleratorConfig objects
+
+    Restrictions:
+        1. The activation_type must be greater than or equal to the weight_type in terms of bitwidth.
+        2. 16-bit float types (BF16, FP16) should not be combined with each other.
+           They should only pair with themselves or with FP32.
+    """
+    if dtypes is None:
+        dtypes = [Float8, BF16, Float32]
+
+    if multipliers is None:
+        multipliers = [float_multiplier, lmul_simple]
+
+    # Sort dtypes by bitwidth for easier comparison
+    dtype_bitwidths = {dtype: dtype.bitwidth() for dtype in dtypes}
+    sorted_dtypes = sorted(dtypes, key=lambda d: dtype_bitwidths[d])
+
+    # Identify 16-bit float types
+    bit16_float_types = [dtype for dtype in dtypes if dtype_bitwidths[dtype] == 16]
+
+    # Generate all combinations
+    for multiplier in multipliers:
+        for weight_type in sorted_dtypes:
+            # Find valid activation types based on bitwidth
+            valid_activation_types = [
+                dtype
+                for dtype in sorted_dtypes
+                if dtype_bitwidths[dtype] >= dtype_bitwidths[weight_type]
+            ]
+
+            for activation_type in valid_activation_types:
+                # Skip invalid combinations of 16-bit float types
+                if (
+                    weight_type in bit16_float_types
+                    and activation_type in bit16_float_types
+                    and weight_type != activation_type
+                ):
+                    continue
+
+                yield CompiledAcceleratorConfig(
+                    array_size=array_size,
+                    activation_type=activation_type,
+                    weight_type=weight_type,
+                    multiplier=multiplier,
+                    **kwargs,
+                )
+
+
+def compile_and_save_simulator(config):
+    """Compile and save a simulator for a given configuration.
+
+    Args:
+        config: The CompiledAcceleratorConfig to use
+
+    Returns:
+        Tuple of (config, success, time_taken)
+    """
+    start_time = time.time()
+
+    try:
+        # Create the simulator
+        with pyrtl.temp_working_block():
+            CompiledAcceleratorSimulator(config)
+
+        end_time = time.time()
+        return (config, True, end_time - start_time)
+
+    except Exception as e:
+        end_time = time.time()
+        print(f"Error compiling {config}: {str(e)}")
+        return (config, False, end_time - start_time)
+
+
+def compile_all_simulators(configs=None, max_workers=None):
+    """Compile and save simulators for all configurations using multiprocessing.
+
+    Args:
+        configs: List of configurations to compile. If None, generates all valid configs.
+        max_workers: Maximum number of worker processes. If None, uses CPU count.
+
+    Returns:
+        List of results (config, success, time_taken)
+    """
+    if configs is None:
+        configs = list(generate_accelerator_configs())
+
+    if max_workers is None:
+        max_workers = multiprocessing.cpu_count()
+
+    print(f"Compiling {len(configs)} configurations using {max_workers} workers")
+
+    # Wrap the worker function (partial kept for future keyword arguments)
+    compile_func = partial(compile_and_save_simulator)
+
+    # Use multiprocessing to compile all configurations
+    with multiprocessing.Pool(processes=max_workers) as pool:
+        # Use tqdm to show progress
+        results = list(
+            tqdm(
+                pool.imap(compile_func, configs),
+                total=len(configs),
+                desc="Compiling simulators",
+            )
+        )
+
+    # Print summary
+    successful = [r for r in results if r[1]]
+    failed = [r for r in results if not r[1]]
+
+    print(f"\nCompilation complete:")
+    print(f"  Total: {len(results)}")
+    print(f"  Successful: {len(successful)}")
+    print(f"  Failed: {len(failed)}")
+
+    if successful:
+        avg_time = sum(r[2] for r in successful) / len(successful)
+        print(f"  Average compilation time: {avg_time:.2f} seconds")
+
+    return results
+
+
+if __name__ == "__main__":
+    # Generate all valid configurations
+    all_configs = list(generate_accelerator_configs())
+    print(f"Generated {len(all_configs)} configs")
+
+    # Compile and save simulators for all configurations
+    results = compile_all_simulators(all_configs)
+
+    # Print details of failed compilations if any
+    failed = [r for r in results if not r[1]]
+    if failed:
+        print("\nFailed compilations:")
+        for config, _, _ in failed:
+            print(config.name)
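
As a quick check of the pairing rules documented in generate_accelerator_configs, a sketch that just enumerates the defaults (attribute names assumed from the constructor keywords):

    from hardware_accelerators.compile import generate_accelerator_configs

    configs = list(generate_accelerator_configs(array_size=4))
    for cfg in configs:
        print(cfg.multiplier.__name__, cfg.weight_type.__name__, cfg.activation_type.__name__)
    # Defaults [Float8, BF16, Float32] give 6 weight/activation pairs
    # (each weight pairs with itself or a wider type) x 2 multipliers = 12 configs.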
hardware_accelerators/dtypes/__init__.py
CHANGED
@@ -2,5 +2,7 @@
 from .base import BaseFloat
 from .bfloat16 import BF16
 from .float8 import Float8
+from .float16 import Float16
+from .float32 import Float32
 
-__all__ = ["BaseFloat", "Float8", "BF16"]
+__all__ = ["BaseFloat", "Float8", "BF16", "Float16", "Float32"]
hardware_accelerators/dtypes/base.py
CHANGED
@@ -56,6 +56,11 @@ class BaseFloat(ABC):
         else:
             raise ValueError("Must provide one of: value, binary, or binint")
 
+    @classmethod
+    @abstractmethod
+    def binary_max(cls) -> int:
+        pass
+
     @classmethod
     @abstractmethod
     def format_spec(cls) -> FormatSpec:
@@ -160,7 +165,6 @@ class BaseFloat(ABC):
 
     def _format_binary_string(self, binary=None) -> str:
         """Format binary string with dots for readability"""
-        # Clean the input string first
        if binary is None:
            binary = self.binary
        clean_binary = "".join(c for c in binary if c in "01")
@@ -169,8 +173,13 @@ class BaseFloat(ABC):
 
         if self.bitwidth() == 8:  # Float8
             return f"{clean_binary[0]}.{clean_binary[1:5]}.{clean_binary[5:]}"
-        elif self.bitwidth() == 16:
-            return clean_binary
+        elif self.bitwidth() == 32:  # Float32
+            return f"{clean_binary[0]}.{clean_binary[1:9]}.{clean_binary[9:]}"
+        elif self.bitwidth() == 16:
+            if self.__class__.__name__ == "Float16":  # Float16
+                return f"{clean_binary[0]}.{clean_binary[1:6]}.{clean_binary[6:]}"
+            else:  # BF16
+                return clean_binary
         else:
             return clean_binary
 
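
The patched formatter now dots every supported width. A small sketch of the expected groupings (bit strings follow from the IEEE 754 layouts; treat the exact attribute output as an assumption):

    from hardware_accelerators.dtypes import Float16, Float32

    print(Float32(1.0).binary)  # expected "0.01111111.00000000000000000000000"
    print(Float16(1.0).binary)  # expected "0.01111.0000000000"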
hardware_accelerators/dtypes/bfloat16.py
CHANGED
@@ -29,6 +29,10 @@ class BF16(BaseFloat):
             min_subnormal=2**-126 * (1 / 128),
         )
 
+    @classmethod
+    def binary_max(cls) -> int:
+        return 0b0111111101111111
+
     def _float32_to_bf16_parts(self, f32: float) -> Tuple[int, int, int]:
         """Convert float32 to BF16 parts (sign, exponent, mantissa)"""
         # Get binary representation of float32
hardware_accelerators/dtypes/float16.py
ADDED
@@ -0,0 +1,167 @@
+from .base import BaseFloat, FormatSpec
+
+
+class Float16(BaseFloat):
+    """
+    16-bit floating point number with IEEE 754 half-precision format
+    - 1 sign bit
+    - 5 exponent bits (bias 15)
+    - 10 mantissa bits
+    """
+
+    @classmethod
+    def format_spec(cls) -> FormatSpec:
+        return FormatSpec(
+            total_bits=16,
+            exponent_bits=5,
+            mantissa_bits=10,
+            bias=15,
+            max_normal=65504.0,  # from 0.11110.1111111111
+            min_normal=2**-14,  # from 0.00001.0000000000
+            max_subnormal=2**-14 * (1023 / 1024),  # from 0.00000.1111111111
+            min_subnormal=2**-24,  # from 0.00000.0000000001
+        )
+
+    @classmethod
+    def binary_max(cls) -> int:
+        return 0b0111101111111111
+
+    def _decimal_to_binary(self, num: float) -> str:
+        """Convert decimal number to binary string in IEEE 754 format"""
+        if num == 0:
+            return "0.00000.0000000000"
+
+        # Extract sign bit
+        sign = "1" if num < 0 else "0"
+        num = abs(num)
+
+        # Handle NaN
+        if num != num:  # Python's way to check for NaN
+            return sign + ".11111.1111111111"
+
+        # Clamp to max value if overflow
+        if num > self.max_normal():
+            return "0.11110.1111111111" if sign == "0" else "1.11110.1111111111"
+
+        # Find exponent and normalized mantissa
+        exp = 0
+        temp = num
+
+        # Handle normal numbers
+        while temp >= 2 and exp < 31:
+            temp /= 2
+            exp += 1
+        while temp < 1 and exp > -14:
+            temp *= 2
+            exp -= 1
+
+        # Handle subnormal numbers
+        if exp <= -14:
+            # Shift mantissa right and adjust
+            shift = -14 - exp
+            temp /= 2**shift
+            exp = -14
+
+        # Calculate biased exponent
+        if temp < 1:  # Subnormal
+            biased_exp = "00000"
+        else:  # Normal
+            biased_exp = format(exp + self.bias(), "05b")
+
+        # Calculate mantissa bits
+        if temp < 1:  # Subnormal
+            mantissa_value = int(temp * (2 ** self.mantissa_bits()))
+        else:  # Normal
+            mantissa_value = int((temp - 1) * (2 ** self.mantissa_bits()))
+
+        mantissa = format(mantissa_value, f"0{self.mantissa_bits()}b")
+
+        return f"{sign}.{biased_exp}.{mantissa}"
+
+    def _binary_to_decimal(self, binary: str) -> float:
+        """Convert binary string in IEEE 754 format to decimal"""
+        # Clean up binary string
+        binary = "".join(c for c in binary if c in "01")
+
+        # Extract components
+        sign = -1 if binary[0] == "1" else 1
+        exp = binary[1:6]
+        mantissa = binary[6:]
+
+        # Handle special cases
+        if exp == "11111" and mantissa == "1111111111":  # NaN representation
+            return float("nan")
+        if exp == "00000" and mantissa == "0000000000":
+            return 0.0
+
+        # Convert biased exponent
+        biased_exp = int(exp, 2)
+
+        if biased_exp == 0:  # Subnormal number
+            actual_exp = -14
+            mantissa_value = int(mantissa, 2) / (2 ** self.mantissa_bits())
+            return sign * (2**actual_exp) * mantissa_value
+        else:  # Normal number
+            actual_exp = biased_exp - self.bias()
+            mantissa_value = 1 + int(mantissa, 2) / (2 ** self.mantissa_bits())
+            return sign * (2**actual_exp) * mantissa_value
+
+    @classmethod
+    def from_bits(cls, bits: int) -> "Float16":
+        """Create Float16 from 16-bit integer"""
+        return cls(binint=bits)
+
+    @classmethod
+    def nan(cls) -> "Float16":
+        """Create NaN value"""
+        return cls(binary="1.11111.1111111111")
+
+    @classmethod
+    def max_value(cls) -> "Float16":
+        """Create maximum representable value"""
+        return cls(binary="0.11110.1111111111")
+
+    @classmethod
+    def min_value(cls) -> "Float16":
+        """Create minimum representable normal value"""
+        return cls(binary="0.00001.0000000000")
+
+    @classmethod
+    def min_subnormal(cls) -> "Float16":
+        """Create minimum representable subnormal value"""
+        return cls(binary="0.00000.0000000001")
+
+    def detailed_breakdown(self) -> dict:
+        """Provide detailed breakdown of the Float16 number components"""
+        binary = "".join(c for c in self.binary if c in "01")
+        sign_bit = int(binary[0])
+        exp_bits = binary[1:6]
+        mantissa_bits = binary[6:]
+
+        exp_val = int(exp_bits, 2)
+        mantissa_val = int(mantissa_bits, 2)
+
+        is_normal = exp_val != 0 and exp_val != 31
+        is_subnormal = exp_val == 0 and mantissa_val != 0
+        is_zero = exp_val == 0 and mantissa_val == 0
+        is_nan = (
+            exp_val == 31 and mantissa_val == 1023
+        )  # Only s.11111.1111111111 is NaN
+
+        return {
+            "binary": self.binary,
+            "sign": sign_bit,
+            "exponent_bits": exp_bits,
+            "exponent_value": (exp_val - self.bias() if exp_val != 0 else "subnormal"),
+            "mantissa_bits": mantissa_bits,
+            "mantissa_value": mantissa_val,
+            "decimal_approx": self.decimal_approx,
+            "original_value": self.original_value,
+            "is_normal": is_normal,
+            "is_subnormal": is_subnormal,
+            "is_zero": is_zero,
+            "is_nan": is_nan,
+            "normalized_value": (
+                (1 + mantissa_val / 1024) if is_normal else (mantissa_val / 1024)
+            ),
+        }
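
A round-trip sketch for the converters above, using the constructor keywords BaseFloat accepts (value, binary, or binint):

    from hardware_accelerators.dtypes import Float16

    x = Float16(1.5)  # 1.5 encodes exactly: biased exponent 01111, mantissa .1000000000
    assert Float16(binary=x.binary).decimal_approx == 1.5
    print(Float16.max_value().decimal_approx)  # 65504.0, the largest half-precision normal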
hardware_accelerators/dtypes/float32.py
ADDED
@@ -0,0 +1,174 @@
from .base import BaseFloat, FormatSpec


class Float32(BaseFloat):
    """
    32-bit floating point number with IEEE 754 single-precision format
    - 1 sign bit
    - 8 exponent bits (bias 127)
    - 23 mantissa bits
    """

    @classmethod
    def format_spec(cls) -> FormatSpec:
        return FormatSpec(
            total_bits=32,
            exponent_bits=8,
            mantissa_bits=23,
            bias=127,
            max_normal=3.4028235e38,  # from 0.11111110.11111111111111111111111
            min_normal=2**-126,  # from 0.00000001.00000000000000000000000
            max_subnormal=2**-126
            * (8388607 / 8388608),  # from 0.00000000.11111111111111111111111
            min_subnormal=2**-149,  # from 0.00000000.00000000000000000000001
        )

    @classmethod
    def binary_max(cls) -> int:
        return 0b01111111011111111111111111111111

    def _decimal_to_binary(self, num: float) -> str:
        """Convert decimal number to binary string in IEEE 754 format"""
        if num == 0:
            return "0.00000000.00000000000000000000000"

        # Extract sign bit
        sign = "1" if num < 0 else "0"
        num = abs(num)

        # Handle NaN
        if num != num:  # Python's way to check for NaN
            return sign + ".11111111.11111111111111111111111"

        # Clamp to max value if overflow
        if num > self.max_normal():
            return (
                "0.11111110.11111111111111111111111"
                if sign == "0"
                else "1.11111110.11111111111111111111111"
            )

        # Find exponent and normalized mantissa
        exp = 0
        temp = num

        # Handle normal numbers
        while temp >= 2 and exp < 255:
            temp /= 2
            exp += 1
        while temp < 1 and exp > -126:
            temp *= 2
            exp -= 1

        # Handle subnormal numbers
        if exp <= -126:
            # Shift mantissa right and adjust
            shift = -126 - exp
            temp /= 2**shift
            exp = -126

        # Calculate biased exponent
        if temp < 1:  # Subnormal
            biased_exp = "00000000"
        else:  # Normal
            biased_exp = format(exp + self.bias(), "08b")

        # Calculate mantissa bits
        if temp < 1:  # Subnormal
            mantissa_value = int(temp * (2 ** self.mantissa_bits()))
        else:  # Normal
            mantissa_value = int((temp - 1) * (2 ** self.mantissa_bits()))

        mantissa = format(mantissa_value, f"0{self.mantissa_bits()}b")

        return f"{sign}.{biased_exp}.{mantissa}"

    def _binary_to_decimal(self, binary: str) -> float:
        """Convert binary string in IEEE 754 format to decimal"""
        # Clean up binary string
        binary = "".join(c for c in binary if c in "01")

        # Extract components
        sign = -1 if binary[0] == "1" else 1
        exp = binary[1:9]
        mantissa = binary[9:]

        # Handle special cases
        if (
            exp == "11111111" and mantissa == "11111111111111111111111"
        ):  # NaN representation
            return float("nan")
        if exp == "00000000" and mantissa == "00000000000000000000000":
            return 0.0

        # Convert biased exponent
        biased_exp = int(exp, 2)

        if biased_exp == 0:  # Subnormal number
            actual_exp = -126
            mantissa_value = int(mantissa, 2) / (2 ** self.mantissa_bits())
            return sign * (2**actual_exp) * mantissa_value
        else:  # Normal number
            actual_exp = biased_exp - self.bias()
            mantissa_value = 1 + int(mantissa, 2) / (2 ** self.mantissa_bits())
            return sign * (2**actual_exp) * mantissa_value

    @classmethod
    def from_bits(cls, bits: int) -> "Float32":
        """Create Float32 from 32-bit integer"""
        return cls(binint=bits)

    @classmethod
    def nan(cls) -> "Float32":
        """Create NaN value"""
        return cls(binary="1.11111111.11111111111111111111111")

    @classmethod
    def max_value(cls) -> "Float32":
        """Create maximum representable value"""
        return cls(binary="0.11111110.11111111111111111111111")

    @classmethod
    def min_value(cls) -> "Float32":
        """Create minimum representable normal value"""
        return cls(binary="0.00000001.00000000000000000000000")

    @classmethod
    def min_subnormal(cls) -> "Float32":
        """Create minimum representable subnormal value"""
        return cls(binary="0.00000000.00000000000000000000001")

    def detailed_breakdown(self) -> dict:
        """Provide detailed breakdown of the Float32 number components"""
        binary = "".join(c for c in self.binary if c in "01")
        sign_bit = int(binary[0])
        exp_bits = binary[1:9]
        mantissa_bits = binary[9:]

        exp_val = int(exp_bits, 2)
        mantissa_val = int(mantissa_bits, 2)

        is_normal = exp_val != 0 and exp_val != 255
        is_subnormal = exp_val == 0 and mantissa_val != 0
        is_zero = exp_val == 0 and mantissa_val == 0
        is_nan = (
            exp_val == 255 and mantissa_val == 8388607
        )  # Only s.11111111.11111111111111111111111 is NaN

        return {
            "binary": self.binary,
            "sign": sign_bit,
            "exponent_bits": exp_bits,
            "exponent_value": (exp_val - self.bias() if exp_val != 0 else "subnormal"),
            "mantissa_bits": mantissa_bits,
            "mantissa_value": mantissa_val,
            "decimal_approx": self.decimal_approx,
            "original_value": self.original_value,
            "is_normal": is_normal,
            "is_subnormal": is_subnormal,
            "is_zero": is_zero,
            "is_nan": is_nan,
            "normalized_value": (
                (1 + mantissa_val / 8388608) if is_normal else (mantissa_val / 8388608)
            ),
        }
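The encoding Float32 targets can be cross-checked against Python's own IEEE 754 singles using only the standard library; this sketch reproduces the `sign.exponent.mantissa` strings used throughout the class:

    import struct

    def f32_fields(x: float) -> str:
        """Split a float's single-precision bits into the s.e.m form used above."""
        bits = struct.unpack(">I", struct.pack(">f", x))[0]
        sign = (bits >> 31) & 0x1
        exp = (bits >> 23) & 0xFF
        mant = bits & 0x7FFFFF
        return f"{sign}.{exp:08b}.{mant:023b}"

    print(f32_fields(0.15625))  # 0.01111100.01000000000000000000000 (1.25 * 2**-3)
    print(f32_fields(2**-149))  # 0.00000000.00000000000000000000001 (min subnormal)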
hardware_accelerators/dtypes/float8.py
CHANGED
@@ -23,6 +23,10 @@ class Float8(BaseFloat):
         min_subnormal=2**-6 * (1 / 8),  # from 0.0000.001
     )
 
+    @classmethod
+    def binary_max(cls) -> int:
+        return 0b01111110
+
     def _decimal_to_binary(self, num: float) -> str:
         """Convert decimal number to binary string in E4M3 format"""
         if num == 0:
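`binary_max` gives the bit pattern of the largest finite positive value in each format: for E4M3 that is 0.1111.110, i.e. 448, and the Float32 constant added above decodes to the `max_normal` in `format_spec()`. A standalone check with the standard library:

    import struct

    bits = 0b01111111011111111111111111111111  # Float32.binary_max()
    (value,) = struct.unpack(">f", bits.to_bytes(4, "big"))
    print(value)  # 3.4028234663852886e+38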
hardware_accelerators/nn/lmul.py
ADDED
@@ -0,0 +1,135 @@
import torch
import torch.nn as nn
import torch.nn.functional as F
import torchvision
from torch.utils.data import DataLoader
from torchvision import datasets, transforms
import math
import time


# Custom approximate matrix multiplication using lmul
def lmul_matmul(A: torch.Tensor, B: torch.Tensor, dtype=torch.float32):
    """
    Approximate matrix multiplication between A (m x n) and B (n x p)
    using bitwise operations to mimic multiplication.
    """
    if dtype == torch.float32:
        # reinterpret bits as uint32 then convert to int64 for arithmetic
        A_int = A.contiguous().view(torch.uint32).to(torch.int64)
        B_int = B.contiguous().view(torch.uint32).to(torch.int64)
        offset = 1064828928  # offset for float32
    elif dtype == torch.bfloat16:
        A_int = A.contiguous().view(torch.uint16).to(torch.int64)
        B_int = B.contiguous().view(torch.uint16).to(torch.int64)
        offset = 16248  # offset for bfloat16
    else:
        raise ValueError("Unsupported dtype")

    # A is (m, n) and B is (n, p).
    # Expand dims so that:
    # A_int becomes (m, n, 1) and B_int becomes (1, n, p)
    prod_int = A_int.unsqueeze(2) + B_int.unsqueeze(0) - offset  # shape: (m, n, p)

    # Convert the integer result back to floating point.
    if dtype == torch.float32:
        prod = prod_int.to(torch.uint32).view(torch.float32)
    else:
        prod = prod_int.to(torch.uint16).view(torch.bfloat16)

    # Sum over the inner dimension to complete the dot product.
    return prod.sum(dim=1)


# Custom linear layer that uses lmul-based matrix multiplication
class LmulLinear(nn.Module):
    def __init__(self, in_features, out_features, bias=True, dtype=torch.float32):
        super(LmulLinear, self).__init__()
        self.in_features = in_features
        self.out_features = out_features
        self.dtype = dtype
        self.weight = nn.Parameter(torch.Tensor(out_features, in_features))
        if bias:
            self.bias = nn.Parameter(torch.Tensor(out_features))
        else:
            self.register_parameter("bias", None)
        self.reset_parameters()

    def reset_parameters(self):
        # Initialize weights similarly to nn.Linear.
        nn.init.kaiming_uniform_(self.weight, a=math.sqrt(5))
        if self.bias is not None:
            fan_in, _ = nn.init._calculate_fan_in_and_fan_out(self.weight)
            bound = 1 / math.sqrt(fan_in)
            nn.init.uniform_(self.bias, -bound, bound)

    def forward(self, input):
        # Compute the approximate matrix multiply:
        # Note: input shape is (batch, in_features)
        # weight.T shape is (in_features, out_features)
        out = lmul_matmul(input, self.weight.t(), self.dtype)
        if self.bias is not None:
            out = out + self.bias  # add bias as usual
        return out


# MLP model using our custom lmul-based linear layers
class LmulMLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes, dtype=torch.float32):
        super(LmulMLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = LmulLinear(input_size, hidden_size, bias=True, dtype=dtype)
        self.relu = nn.ReLU()
        self.fc2 = LmulLinear(hidden_size, num_classes, bias=True, dtype=dtype)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        x = self.fc2(x)
        return x


# Setup: use float32 for this example.
dtype = torch.float32

# Instantiate the model.
# For MNIST: input size is 28x28 = 784, hidden layer of 128, output 10 classes.
model = LmulMLP(input_size=784, hidden_size=128, num_classes=10, dtype=dtype)
model.eval()  # set model to evaluation mode

model.load_state_dict(torch.load("models/mlp_mnist_fp32.pth", weights_only=True))

# Prepare the MNIST test dataset.
transform = transforms.Compose(
    [
        transforms.ToTensor(),
        transforms.Normalize((0.1307,), (0.3081,)),
    ]
)
test_dataset = datasets.MNIST(
    root="./data", train=False, transform=transform, download=True
)
test_loader = DataLoader(test_dataset, batch_size=32, shuffle=False)

# Run inference on the test dataset and measure accuracy.
correct = 0
total = 0
start_time = time.time()

with torch.no_grad():
    for images, labels in test_loader:
        # Ensure images are in the right dtype
        images = images.to(dtype)
        outputs = model(images)
        # Compute predictions
        _, predicted = torch.max(outputs, 1)
        total += labels.size(0)
        correct += (predicted.cpu() == labels).sum().item()

end_time = time.time()
accuracy = correct / total * 100
inference_time = end_time - start_time

print(f"Test Accuracy: {accuracy:.2f}%")
print(f"Inference Time on Test Set: {inference_time:.2f} seconds")
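Where do 1064828928 and 16248 come from? They are consistent with the l-mul construction (approximating a float multiply by integer addition of the IEEE bit patterns): adding two encodings doubles the exponent bias, and the offset removes one bias worth of exponent while folding in a 2**-4 mantissa correction, i.e. offset = bias * 2**m - 2**m // 16 for m mantissa bits. This reading is inferred from the constants themselves, not stated in the commit, but the arithmetic checks out:

    # offset = bias * 2**m - 2**m // 16, with m = mantissa bits of the format
    print(127 * 2**23 - 2**23 // 16)  # 1064828928 -> the float32 offset above
    print(127 * 2**7 - 2**7 // 16)    # 16248      -> the bfloat16 offset above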
hardware_accelerators/nn/precision.py
ADDED
@@ -0,0 +1,264 @@
import os
import torch
import torch.nn as nn
import torch.optim as optim
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
from tqdm.auto import tqdm

from .util import get_pytorch_device


# Define the MLP model (unchanged)
class MLP(nn.Module):
    def __init__(self, input_size, hidden_size, num_classes):
        super(MLP, self).__init__()
        self.flatten = nn.Flatten()
        self.fc1 = nn.Linear(input_size, hidden_size)
        self.relu = nn.ReLU()
        self.fc2 = nn.Linear(hidden_size, num_classes)

    def forward(self, x):
        x = self.flatten(x)
        x = self.fc1(x)
        x = self.relu(x)
        out = self.fc2(x)
        return out


# Helper function: adjust data to match the target dtype
def convert_input(data, precision):
    if precision == "fp16":
        return data.half()
    elif precision == "bf16":
        return data.to(torch.bfloat16)
    elif precision == "fp8":
        # Note: torch.float8_e4m3fn is experimental and may not be available
        return data.to(torch.float8_e4m3fn)
    return data  # fp32 (no conversion)


# Training for one epoch
def train_epoch(model, device, train_loader, optimizer, criterion, precision):
    model.train()
    running_loss = 0.0
    progress_bar = tqdm(train_loader, desc="Training", leave=False)
    for data, target in progress_bar:
        # Convert inputs to the desired precision (targets remain integer)
        data = convert_input(data, precision)
        data, target = data.to(device), target.to(device)
        optimizer.zero_grad()

        # Forward pass
        outputs = model(data)
        loss = criterion(outputs, target)

        # Check for NaN and skip problematic batches
        if torch.isnan(loss):
            print("NaN loss detected in batch, skipping...")
            continue

        # Backward and optimize with gradient clipping
        loss.backward()

        # Apply gradient clipping to prevent exploding gradients
        torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

        optimizer.step()
        running_loss += loss.item()

    if len(train_loader) > 0:
        return running_loss / len(train_loader)
    return 0.0


# Evaluation loop
def evaluate(model, device, test_loader, criterion, precision):
    model.eval()
    total_loss = 0.0
    correct = 0
    total = 0
    with torch.no_grad():
        for data, target in test_loader:
            data = convert_input(data, precision)
            data, target = data.to(device), target.to(device)
            outputs = model(data)
            loss = criterion(outputs, target)
            total_loss += loss.item()
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()
    avg_loss = total_loss / len(test_loader)
    accuracy = 100.0 * correct / total
    return avg_loss, accuracy


# Main training function for a given precision variant
def train_model(
    precision,
    batch_size=32,
    hidden_size=128,
    num_epochs=5,
    learning_rate=0.001,
    optimizer_name="adam",
    weight_decay=0,
    eps=1e-4,
    model_save_path=None,
):
    print(f"\nTraining in {precision.upper()} mode:")
    device = get_pytorch_device()

    # Data transformation: images are loaded as FP32 by default
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )

    train_dataset = datasets.MNIST(
        root="./data", train=True, download=True, transform=transform
    )
    test_dataset = datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)

    # Hyperparameters
    input_size = 28 * 28  # MNIST images are 28x28
    num_classes = 10

    # Create the model and send to device
    model = MLP(input_size, hidden_size, num_classes).to(device)

    # Convert the model to the target precision (natively)
    if precision == "fp16":
        model = model.to(torch.float16)
        # Use a smaller learning rate for half precision if not explicitly specified
        if learning_rate == 0.001:  # If using the default value
            learning_rate = 1e-4  # Lower learning rate for stability
    elif precision == "bf16":
        model = model.to(torch.bfloat16)
    elif precision == "fp8":
        # Ensure your PyTorch build/hardware supports float8_e4m3fn; otherwise, this will error.
        model = model.to(torch.float8_e4m3fn)
    # else, fp32 is already the default

    # Select optimizer based on user input
    if optimizer_name.lower() == "adam":
        optimizer = optim.Adam(
            model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay
        )
    elif optimizer_name.lower() == "sgd":
        optimizer = optim.SGD(
            model.parameters(),
            lr=learning_rate,
            momentum=0.9,
            weight_decay=weight_decay,
        )
    elif optimizer_name.lower() == "adamw":
        optimizer = optim.AdamW(
            model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay
        )
    else:
        print(f"Unknown optimizer: {optimizer_name}, defaulting to Adam")
        optimizer = optim.Adam(
            model.parameters(), lr=learning_rate, eps=eps, weight_decay=weight_decay
        )

    criterion = nn.CrossEntropyLoss()

    print(
        f"Training with: batch_size={batch_size}, hidden_size={hidden_size}, "
        f"epochs={num_epochs}, lr={learning_rate}, optimizer={optimizer_name}"
    )

    # Training loop
    for epoch in range(1, num_epochs + 1):
        train_loss = train_epoch(
            model, device, train_loader, optimizer, criterion, precision
        )

        # Check for NaN loss
        if torch.isnan(torch.tensor([train_loss])):
            print(f"NaN detected in epoch {epoch}, reducing learning rate")
            for param_group in optimizer.param_groups:
                param_group["lr"] *= 0.5

        print(f"Epoch {epoch} Train Loss: {train_loss:.4f}")

    # Evaluation on test set
    test_loss, test_accuracy = evaluate(
        model, device, test_loader, criterion, precision
    )
    print(
        f"{precision.upper()} Test Loss: {test_loss:.4f}, Test Accuracy: {test_accuracy:.2f}%"
    )

    # Optionally, save the model
    if model_save_path:
        save_path = model_save_path
    else:
        model_dir = "models"
        os.makedirs(model_dir, exist_ok=True)
        save_path = os.path.join(model_dir, f"mlp_mnist_{precision}.pth")

    torch.save(model.state_dict(), save_path)
    print(f"Model saved to {save_path}\n")


# Main script to train a model in a specific precision
if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Train MNIST model in a specific precision"
    )
    parser.add_argument(
        "--dtype",
        type=str,
        default="fp32",
        choices=["fp32", "fp16", "bf16", "fp8"],
        help="Precision type to train in (fp32, fp16, bf16, fp8)",
    )
    parser.add_argument(
        "--batch-size", type=int, default=32, help="Batch size for training"
    )
    parser.add_argument(
        "--hidden-size", type=int, default=128, help="Hidden layer size for MLP"
    )
    parser.add_argument(
        "--epochs", type=int, default=5, help="Number of training epochs"
    )
    parser.add_argument("--lr", type=float, default=0.001, help="Learning rate")
    parser.add_argument(
        "--optimizer",
        type=str,
        default="adam",
        choices=["adam", "sgd", "adamw"],
        help="Optimizer to use for training",
    )
    parser.add_argument(
        "--weight-decay", type=float, default=0, help="Weight decay (L2 penalty)"
    )
    parser.add_argument(
        "--eps", type=float, default=1e-4, help="Epsilon for Adam optimizer"
    )
    parser.add_argument(
        "--save-path", type=str, default=None, help="Path to save the trained model"
    )

    args = parser.parse_args()

    try:
        train_model(
            precision=args.dtype,
            batch_size=args.batch_size,
            hidden_size=args.hidden_size,
            num_epochs=args.epochs,
            learning_rate=args.lr,
            optimizer_name=args.optimizer,
            weight_decay=args.weight_decay,
            eps=args.eps,
            model_save_path=args.save_path,
        )
    except Exception as e:
        print(f"Error training {args.dtype.upper()} model: {e}")
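With the argparse block above, each precision variant can also be trained programmatically through `train_model`. A minimal sketch (arguments and defaults are taken straight from the definitions above; MNIST downloads to ./data on first run):

    from hardware_accelerators.nn.precision import train_model

    # Train the 784-128-10 MLP natively in bfloat16 and save the checkpoint.
    train_model(
        precision="bf16",
        batch_size=64,
        num_epochs=5,
        optimizer_name="adamw",
        model_save_path="models/mlp_mnist_bf16.pth",
    )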
hardware_accelerators/nn/precision_eval.py
ADDED
@@ -0,0 +1,280 @@
import torch
import torch.nn as nn
from torchvision import datasets, transforms
from torch.utils.data import DataLoader
import numpy as np
import time
from pathlib import Path

from .precision import MLP
from .util import get_pytorch_device


def load_mnist_data(batch_size=100):
    """Load MNIST test dataset"""
    transform = transforms.Compose(
        [transforms.ToTensor(), transforms.Normalize((0.1307,), (0.3081,))]
    )
    test_dataset = datasets.MNIST(
        root="./data", train=False, download=True, transform=transform
    )
    test_loader = DataLoader(test_dataset, batch_size=batch_size, shuffle=False)
    return test_loader


def create_model(precision):
    """Create MLP model with specified precision"""
    input_size = 28 * 28  # MNIST images are 28x28
    hidden_size = 128
    num_classes = 10
    device = get_pytorch_device()

    model = MLP(input_size, hidden_size, num_classes).to(device)

    # Convert model to target precision
    if precision == "fp16":
        model = model.to(torch.float16)
    elif precision == "bf16":
        model = model.to(torch.bfloat16)
    elif precision == "fp32":
        model = model.to(torch.float32)

    return model, device


def load_model_weights(model, model_path):
    """Load model weights from checkpoint"""
    model.load_state_dict(torch.load(model_path))
    return model


def evaluate_model(model, test_loader, device, precision):
    """Evaluate model accuracy and inference time"""
    model.eval()
    correct = 0
    total = 0

    # For measuring inference time
    start_time = time.time()

    with torch.no_grad():
        for data, target in test_loader:
            # Convert input to specified precision
            if precision == "fp16":
                data = data.half()
            elif precision == "bf16":
                data = data.to(torch.bfloat16)

            data, target = data.to(device), target.to(device)

            # Forward pass
            outputs = model(data)

            # Calculate accuracy
            _, predicted = torch.max(outputs, 1)
            total += target.size(0)
            correct += (predicted == target).sum().item()

    inference_time = time.time() - start_time
    accuracy = 100.0 * correct / total

    return {
        "accuracy": accuracy,
        "inference_time": inference_time,
        "correct": correct,
        "total": total,
    }


def compare_precision_inference(fp32_model_path):
    """Compare FP32 trained model inference in different precisions"""
    print("\nEvaluating FP32-trained model inference with different precisions")
    print("-" * 80)

    # Load test data
    test_loader = load_mnist_data()

    # Verify the model file exists
    model_path = Path(fp32_model_path)
    if not model_path.exists():
        print(f"Error: Model file {fp32_model_path} not found!")
        return

    # Create models in different precisions
    precisions = ["fp32", "bf16"]
    results = {}

    for precision in precisions:
        print(f"\nTesting inference in {precision.upper()} mode...")

        # Create model in specified precision
        model, device = create_model(precision)

        # Load weights from the FP32-trained model
        # When loading to a BF16 model, the weights will be automatically cast
        model = load_model_weights(model, fp32_model_path)

        # Evaluate model
        results[precision] = evaluate_model(model, test_loader, device, precision)

        # Print results
        print(f"Accuracy: {results[precision]['accuracy']:.2f}%")
        print(f"Inference time: {results[precision]['inference_time']:.4f} seconds")

    # Calculate and print comparison metrics
    print("\nComparison Summary")
    print("-" * 80)

    fp32_results = results["fp32"]
    bf16_results = results["bf16"]

    acc_diff = bf16_results["accuracy"] - fp32_results["accuracy"]
    time_ratio = fp32_results["inference_time"] / bf16_results["inference_time"]

    print(f"Accuracy drop from FP32 to BF16: {acc_diff:.2f}%")
    print(f"Inference speedup with BF16: {time_ratio:.2f}x")

    return results


def detailed_precision_comparison(
    fp32_model_path, trials=3, batch_sizes=[1, 16, 32, 64, 128, 256]
):
    """Run detailed comparison with multiple batch sizes and trials"""
    print("\nDetailed Precision Comparison")
    print("=" * 80)

    # Verify the model file exists
    model_path = Path(fp32_model_path)
    if not model_path.exists():
        print(f"Error: Model file {fp32_model_path} not found!")
        return

    precisions = ["fp32", "bf16"]
    all_results = {}

    for precision in precisions:
        all_results[precision] = {}

        print(f"\nEvaluating {precision.upper()} precision")
        print("-" * 60)

        # Create model in specified precision only once
        model, device = create_model(precision)
        model = load_model_weights(model, fp32_model_path)

        # Warm up the GPU/CPU
        print("Warming up...")
        dummy_loader = load_mnist_data(batch_size=64)
        with torch.no_grad():
            for data, _ in dummy_loader:
                if precision == "bf16":
                    data = data.to(torch.bfloat16)
                elif precision == "fp16":
                    data = data.half()
                data = data.to(device)
                _ = model(data)
                break

        # Run trials for different batch sizes
        for batch_size in batch_sizes:
            print(f"\n  Batch size: {batch_size}")
            test_loader = load_mnist_data(batch_size=batch_size)

            batch_results = {"accuracy": [], "inference_time": []}

            for trial in range(trials):
                print(f"    Trial {trial+1}/{trials}...", end="", flush=True)
                result = evaluate_model(model, test_loader, device, precision)
                batch_results["accuracy"].append(result["accuracy"])
                batch_results["inference_time"].append(result["inference_time"])
                print(
                    f" done. Time: {result['inference_time']:.4f}s, Accuracy: {result['accuracy']:.2f}%"
                )

            # Calculate averages
            avg_accuracy = sum(batch_results["accuracy"]) / len(
                batch_results["accuracy"]
            )
            avg_time = sum(batch_results["inference_time"]) / len(
                batch_results["inference_time"]
            )

            all_results[precision][batch_size] = {
                "avg_accuracy": avg_accuracy,
                "avg_inference_time": avg_time,
                "trials": batch_results,
            }

            print(f"    Average: {avg_time:.4f}s, Accuracy: {avg_accuracy:.2f}%")

    # Print comparison table
    print("\nComparison Results")
    print("=" * 80)
    print(
        f"{'Batch Size':^10} | {'FP32 Acc':^10} | {'BF16 Acc':^10} | {'Acc Diff':^10} | {'FP32 Time':^10} | {'BF16 Time':^10} | {'Speedup':^10}"
    )
    print("-" * 80)

    for batch_size in batch_sizes:
        fp32_acc = all_results["fp32"][batch_size]["avg_accuracy"]
        bf16_acc = all_results["bf16"][batch_size]["avg_accuracy"]
        acc_diff = bf16_acc - fp32_acc

        fp32_time = all_results["fp32"][batch_size]["avg_inference_time"]
        bf16_time = all_results["bf16"][batch_size]["avg_inference_time"]
        speedup = fp32_time / bf16_time

        print(
            f"{batch_size:^10} | {fp32_acc:^10.2f} | {bf16_acc:^10.2f} | {acc_diff:^10.2f} | {fp32_time:^10.4f} | {bf16_time:^10.4f} | {speedup:^10.2f}x"
        )

    # Calculate and print overall averages
    avg_acc_diff = sum(
        all_results["bf16"][bs]["avg_accuracy"]
        - all_results["fp32"][bs]["avg_accuracy"]
        for bs in batch_sizes
    ) / len(batch_sizes)
    avg_speedup = sum(
        all_results["fp32"][bs]["avg_inference_time"]
        / all_results["bf16"][bs]["avg_inference_time"]
        for bs in batch_sizes
    ) / len(batch_sizes)

    print("-" * 80)
    print(f"Average accuracy difference across all batch sizes: {avg_acc_diff:.2f}%")
    print(f"Average speedup across all batch sizes: {avg_speedup:.2f}x")

    return all_results


if __name__ == "__main__":
    import argparse

    parser = argparse.ArgumentParser(
        description="Compare model inference with FP32 vs BF16"
    )
    parser.add_argument(
        "--model_path",
        type=str,
        default="models/mlp_mnist_fp32.pth",
        help="Path to FP32 trained model weights",
    )
    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Run detailed comparison with multiple batch sizes",
    )
    parser.add_argument(
        "--trials",
        type=int,
        default=3,
        help="Number of trials to run (only with --detailed)",
    )

    args = parser.parse_args()

    if args.detailed:
        detailed_precision_comparison(args.model_path, trials=args.trials)
    else:
        compare_precision_inference(args.model_path)
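One subtlety `compare_precision_inference` relies on: `load_state_dict` casts incoming tensors to the destination parameter's dtype, so the FP32 checkpoint loads cleanly into a BF16 model. A standalone check with a tiny throwaway module (not from the repo):

    import torch
    import torch.nn as nn

    src = nn.Linear(4, 2)                     # fresh fp32 weights
    dst = nn.Linear(4, 2).to(torch.bfloat16)  # bf16 module of the same shape
    dst.load_state_dict(src.state_dict())     # values are cast on copy
    print(dst.weight.dtype)                   # torch.bfloat16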
hardware_accelerators/nn/run_precision_comparison.py
ADDED
@@ -0,0 +1,78 @@
#!/usr/bin/env python3

import os
import argparse
from hardware_accelerators.nn.precision import train_model
from hardware_accelerators.nn.precision_eval import (
    compare_precision_inference,
    detailed_precision_comparison,
)


def main():
    # Parse command line arguments
    parser = argparse.ArgumentParser(
        description="Train and evaluate precision differences"
    )
    parser.add_argument(
        "--detailed",
        action="store_true",
        help="Run detailed comparison with multiple batch sizes and trials",
    )
    parser.add_argument(
        "--trials",
        type=int,
        default=3,
        help="Number of trials to run for each batch size (only with --detailed)",
    )
    parser.add_argument(
        "--force-train",
        action="store_true",
        help="Force training a new model even if one exists",
    )
    parser.add_argument(
        "--batch-sizes",
        nargs="+",
        type=int,
        default=[1, 16, 32, 64, 128, 256],
        help="Batch sizes to test (only with --detailed)",
    )
    args = parser.parse_args()

    # Create models directory if it doesn't exist
    os.makedirs("models", exist_ok=True)

    # Path to save/load model
    model_path = "models/mlp_mnist_fp32.pth"

    # Check if model exists, train if not or if forced
    if not os.path.exists(model_path) or args.force_train:
        print("Training a new FP32 model...")
        # Train a model in FP32 precision
        train_model(
            precision="fp32",
            batch_size=64,
            hidden_size=128,
            num_epochs=5,
            learning_rate=0.001,
            optimizer_name="adam",
            model_save_path=model_path,
        )
    else:
        print(f"Using existing model at {model_path}")

    # Run comparison
    if args.detailed:
        print(
            f"Running detailed comparison with {args.trials} trials for each batch size..."
        )
        detailed_precision_comparison(
            model_path, trials=args.trials, batch_sizes=args.batch_sizes
        )
    else:
        # Run basic comparison
        compare_precision_inference(model_path)


if __name__ == "__main__":
    main()
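The runner above wires training and evaluation together; the same sweep can be invoked programmatically (arguments mirror the script's --detailed, --trials, and --batch-sizes flags):

    from hardware_accelerators.nn.precision_eval import detailed_precision_comparison

    # Equivalent to: --detailed --trials 5 --batch-sizes 1 64 256
    results = detailed_precision_comparison(
        "models/mlp_mnist_fp32.pth", trials=5, batch_sizes=[1, 64, 256]
    )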
hardware_accelerators/nn/train.py
CHANGED
@@ -8,8 +8,6 @@ from tqdm.auto import tqdm
 
 from .util import model_factory, get_pytorch_device  # progress bar for notebooks
 
-# from pytorch2tikz import Architecture
-
 
 # Training function for one epoch
 def train(model, device, train_loader, optimizer, criterion, epoch, num_epochs):
hardware_accelerators/nn/util.py
CHANGED
@@ -27,8 +27,10 @@ def load_model(model_path: str, device: torch.device | None = None):
     if device is None:
         device = get_pytorch_device()
     model = model_factory()
-    model.to(device)
-    model.load_state_dict(torch.load(model_path, map_location=device))
+    model.load_state_dict(
+        torch.load(model_path, map_location=device, weights_only=True)
+    )
+    model.to(device)
     return model
 
 
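The substantive change here is `weights_only=True` (available in recent PyTorch releases): plain `torch.load` unpickles arbitrary Python objects, while the restricted mode only accepts tensors and primitive containers, so an untrusted checkpoint cannot execute code on load. A standalone illustration:

    import torch

    torch.save({"w": torch.zeros(2, 2)}, "/tmp/ckpt.pth")
    state = torch.load("/tmp/ckpt.pth", map_location="cpu", weights_only=True)
    print(state["w"].shape)  # torch.Size([2, 2])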
hardware_accelerators/rtllib/__init__.py
CHANGED
@@ -9,9 +9,13 @@ from .accelerator import (
     TiledMatrixEngine,
     AcceleratorConfig,
     Accelerator,
+    CompiledAcceleratorConfig,
+    CompiledAccelerator,
+    AcceleratorAnalysisConfig,
+    AcceleratorTopLevel,
 )
 
-all = [
+__all__ = [
     "float_adder",
     "FloatAdderPipelined",
     "float_multiplier",
@@ -20,11 +24,15 @@ all = [
     "lmul_fast",
     "LmulPipelined",
     "SystolicArrayDiP",
-    "
+    "TiledAccumulatorMemoryBank",
     "BufferMemory",
     "WeightFIFO",
     "TiledAcceleratorConfig",
     "TiledMatrixEngine",
     "AcceleratorConfig",
     "Accelerator",
+    "CompiledAcceleratorConfig",
+    "CompiledAccelerator",
+    "AcceleratorAnalysisConfig",
+    "AcceleratorTopLevel",
 ]
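A note on the export-list fix: the old module defined `all = [...]` (with a dangling `"` entry), which Python treats as an ordinary variable; only the dunder `__all__` controls what `from hardware_accelerators.rtllib import *` re-exports. After this change the advertised surface can be checked directly:

    import hardware_accelerators.rtllib as rtllib

    # The new classes are now genuinely part of the star-import surface.
    assert "CompiledAccelerator" in rtllib.__all__
    assert "TiledAccumulatorMemoryBank" in rtllib.__all__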
hardware_accelerators/rtllib/accelerator.py
CHANGED
|
@@ -1,99 +1,140 @@
|
|
|
|
|
| 1 |
from dataclasses import dataclass
|
| 2 |
-
from typing import Callable, Type, Dict
|
| 3 |
import numpy as np
|
|
|
|
|
|
|
| 4 |
|
| 5 |
from pyrtl import (
|
| 6 |
WireVector,
|
| 7 |
Register,
|
|
|
|
| 8 |
Output,
|
| 9 |
Simulation,
|
|
|
|
| 10 |
concat,
|
| 11 |
)
|
| 12 |
|
|
|
|
| 13 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 14 |
from .buffer import BufferMemory, WeightFIFO
|
| 15 |
from .systolic import SystolicArrayDiP
|
| 16 |
from .accumulators import Accumulator, TiledAccumulatorMemoryBank
|
| 17 |
-
from .activations import ReluUnit
|
| 18 |
from ..dtypes import BaseFloat
|
| 19 |
|
|
|
|
| 20 |
|
| 21 |
-
@dataclass
|
| 22 |
-
class AcceleratorConfig:
|
| 23 |
-
"""Configuration class for a systolic array accelerator.
|
| 24 |
|
| 25 |
-
|
| 26 |
-
|
| 27 |
-
|
| 28 |
-
"""
|
| 29 |
|
| 30 |
array_size: int
|
| 31 |
-
|
| 32 |
-
|
| 33 |
-
num_weight_tiles: int
|
| 34 |
-
"""Number of weight tiles in the FIFO. Each tile is equal to the size of the systolic array"""
|
| 35 |
-
|
| 36 |
-
data_type: Type[BaseFloat]
|
| 37 |
-
"""Floating point format of input data to systolic array"""
|
| 38 |
-
|
| 39 |
weight_type: Type[BaseFloat]
|
| 40 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 41 |
|
| 42 |
-
|
| 43 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 44 |
|
| 45 |
-
|
| 46 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 47 |
|
| 48 |
-
|
| 49 |
-
|
| 50 |
|
| 51 |
-
|
| 52 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 53 |
|
| 54 |
-
|
| 55 |
-
|
|
|
|
|
|
|
| 56 |
|
| 57 |
-
|
| 58 |
-
|
| 59 |
|
| 60 |
@property
|
| 61 |
-
def
|
| 62 |
-
"""Get
|
| 63 |
-
return (self.
|
| 64 |
|
| 65 |
|
| 66 |
-
class
|
| 67 |
-
def __init__(self, config:
|
| 68 |
self.config = config
|
| 69 |
|
| 70 |
# Instantiate hardware components
|
| 71 |
-
self.fifo = WeightFIFO(
|
| 72 |
-
array_size=config.array_size,
|
| 73 |
-
num_tiles=config.num_weight_tiles,
|
| 74 |
-
dtype=config.weight_type,
|
| 75 |
-
)
|
| 76 |
self.systolic_array = SystolicArrayDiP(
|
| 77 |
size=config.array_size,
|
| 78 |
-
data_type=config.
|
| 79 |
weight_type=config.weight_type,
|
| 80 |
-
accum_type=config.
|
| 81 |
-
multiplier=config.
|
| 82 |
-
adder=
|
| 83 |
-
pipeline=config.
|
| 84 |
)
|
| 85 |
self.accumulator = Accumulator(
|
| 86 |
-
addr_width=
|
| 87 |
array_size=config.array_size,
|
| 88 |
-
data_type=config.
|
| 89 |
-
adder=
|
| 90 |
)
|
| 91 |
self.activation = ReluUnit(
|
| 92 |
size=config.array_size,
|
| 93 |
-
dtype=config.
|
| 94 |
)
|
| 95 |
self.outputs = [
|
| 96 |
-
WireVector(config.
|
|
|
|
| 97 |
]
|
| 98 |
|
| 99 |
# Connect components
|
|
@@ -103,12 +144,15 @@ class Accelerator:
|
|
| 103 |
"""Create unnamed WireVectors for control signals"""
|
| 104 |
self.data_enable = WireVector(1)
|
| 105 |
self.data_ins = [
|
| 106 |
-
WireVector(self.config.
|
| 107 |
for _ in range(self.config.array_size)
|
| 108 |
]
|
| 109 |
|
| 110 |
-
self.
|
| 111 |
-
self.
|
|
|
|
|
|
|
|
|
|
| 112 |
|
| 113 |
self.accum_addr_in = WireVector(self.config.accum_addr_width)
|
| 114 |
self.accum_mode_in = WireVector(1)
|
|
@@ -117,7 +161,7 @@ class Accelerator:
|
|
| 117 |
self.act_func_in = WireVector(1) # Apply activation function or passthrough
|
| 118 |
|
| 119 |
def _create_pipeline_registers(self):
|
| 120 |
-
num_registers = self.config.array_size + 1 + int(self.config.
|
| 121 |
|
| 122 |
self.accum_addr_regs = [
|
| 123 |
Register(self.config.accum_addr_width) for _ in range(num_registers)
|
|
@@ -153,16 +197,11 @@ class Accelerator:
|
|
| 153 |
self._create_pipeline_registers()
|
| 154 |
|
| 155 |
# Connect buffer to external inputs
|
| 156 |
-
self.fifo.connect_inputs(
|
| 157 |
-
start=self.weight_start_in,
|
| 158 |
-
tile_addr=self.weight_tile_addr_in,
|
| 159 |
-
)
|
| 160 |
-
|
| 161 |
self.systolic_array.connect_inputs(
|
| 162 |
data_inputs=self.data_ins,
|
| 163 |
enable_input=self.data_enable,
|
| 164 |
-
weight_inputs=self.
|
| 165 |
-
weight_enable=self.
|
| 166 |
)
|
| 167 |
|
| 168 |
# Connect accumulator to systolic array
|
|
@@ -188,8 +227,8 @@ class Accelerator:
|
|
| 188 |
self,
|
| 189 |
data_enable: WireVector | None = None,
|
| 190 |
data_inputs: list[WireVector] | None = None,
|
| 191 |
-
|
| 192 |
-
|
| 193 |
accum_addr: WireVector | None = None,
|
| 194 |
accum_mode: WireVector | None = None,
|
| 195 |
act_start: WireVector | None = None,
|
|
@@ -204,9 +243,8 @@ class Accelerator:
|
|
| 204 |
Args:
|
| 205 |
data_enable: 1-bit signal that enables data flow into the systolic array
|
| 206 |
data_inputs: List of input data wires for the systolic array. Must match array_size
|
| 207 |
-
|
| 208 |
-
|
| 209 |
-
Width must match the FIFO's tile address width
|
| 210 |
accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
|
| 211 |
accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
|
| 212 |
act_start: 1-bit signal to enable passing data through the activation unit
|
|
@@ -226,22 +264,27 @@ class Accelerator:
|
|
| 226 |
f"Expected {self.config.array_size}, got {len(data_inputs)}"
|
| 227 |
)
|
| 228 |
for i, wire in enumerate(data_inputs):
|
| 229 |
-
assert len(wire) == self.config.
|
| 230 |
f"Data input width mismatch. "
|
| 231 |
-
f"Expected {self.config.
|
| 232 |
)
|
| 233 |
self.data_ins[i] <<= wire
|
| 234 |
|
| 235 |
-
if
|
| 236 |
-
assert len(
|
| 237 |
-
self.
|
| 238 |
|
| 239 |
-
if
|
| 240 |
-
assert len(
|
| 241 |
-
f"
|
| 242 |
-
f"Expected {self.
|
| 243 |
)
|
| 244 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 245 |
|
| 246 |
if accum_addr is not None:
|
| 247 |
assert len(accum_addr) == self.config.accum_addr_width, (
|
|
@@ -278,12 +321,8 @@ class Accelerator:
|
|
| 278 |
assert len(valid) == 1, "Output valid signal must be a single bit wire"
|
| 279 |
valid <<= self.activation.outputs_valid
|
| 280 |
|
| 281 |
-
def
|
| 282 |
-
"""Return
|
| 283 |
-
return self.systolic_array.get_state(sim)
|
| 284 |
-
|
| 285 |
-
def inspect_accumulator_state(self, sim: Simulation) -> np.ndarray:
|
| 286 |
-
"""Return all accumulator tiles as 3D array.
|
| 287 |
|
| 288 |
Args:
|
| 289 |
sim: PyRTL simulation instance
|
|
@@ -296,18 +335,237 @@ class Accelerator:
|
|
| 296 |
tiles = []
|
| 297 |
for addr in range(2**self.config.accum_addr_width):
|
| 298 |
row = [
|
| 299 |
-
float(
|
|
|
|
|
|
|
|
|
|
|
|
|
| 300 |
for bank in self.accumulator.memory_banks
|
| 301 |
]
|
| 302 |
tiles.append(row)
|
| 303 |
return np.array(tiles)
|
| 304 |
|
| 305 |
|
| 306 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 307 |
def __init__(self, config: AcceleratorConfig):
|
| 308 |
self.config = config
|
| 309 |
|
| 310 |
# Instantiate hardware components
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 311 |
self.systolic_array = SystolicArrayDiP(
|
| 312 |
size=config.array_size,
|
| 313 |
data_type=config.data_type,
|
|
@@ -342,11 +600,8 @@ class CompiledAccelerator:
|
|
| 342 |
for _ in range(self.config.array_size)
|
| 343 |
]
|
| 344 |
|
| 345 |
-
self.
|
| 346 |
-
self.
|
| 347 |
-
WireVector(self.config.weight_type.bitwidth())
|
| 348 |
-
for _ in range(self.config.array_size)
|
| 349 |
-
]
|
| 350 |
|
| 351 |
self.accum_addr_in = WireVector(self.config.accum_addr_width)
|
| 352 |
self.accum_mode_in = WireVector(1)
|
|
@@ -367,23 +622,33 @@ class CompiledAccelerator:
|
|
| 367 |
self.accum_mode_out = WireVector(1)
|
| 368 |
self.accum_mode_out <<= self.accum_mode_regs[-1]
|
| 369 |
|
| 370 |
-
self.
|
| 371 |
-
self.
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 372 |
|
| 373 |
self.accum_addr_regs[0].next <<= self.accum_addr_in
|
| 374 |
self.accum_mode_regs[0].next <<= self.accum_mode_in
|
| 375 |
for i in range(1, len(self.accum_addr_regs)):
|
| 376 |
self.accum_addr_regs[i].next <<= self.accum_addr_regs[i - 1]
|
| 377 |
self.accum_mode_regs[i].next <<= self.accum_mode_regs[i - 1]
|
| 378 |
-
self.act_control_regs[i].next <<= self.act_control_regs[i - 1]
|
|
|
|
|
|
|
|
|
|
| 379 |
|
| 380 |
self.act_addr = Register(self.config.accum_addr_width)
|
| 381 |
self.act_func = Register(1)
|
| 382 |
self.act_start = Register(1)
|
| 383 |
|
| 384 |
self.act_addr.next <<= self.accum_addr_out
|
| 385 |
-
self.act_func.next <<= self.act_control_regs[-1][0]
|
| 386 |
-
self.act_start.next <<= self.act_control_regs[-1][1]
|
|
|
|
|
|
|
| 387 |
|
| 388 |
def _connect_components(self):
|
| 389 |
"""Internal component connections"""
|
|
@@ -391,11 +656,16 @@ class CompiledAccelerator:
|
|
| 391 |
self._create_pipeline_registers()
|
| 392 |
|
| 393 |
# Connect buffer to external inputs
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 394 |
self.systolic_array.connect_inputs(
|
| 395 |
data_inputs=self.data_ins,
|
| 396 |
enable_input=self.data_enable,
|
| 397 |
-
weight_inputs=self.
|
| 398 |
-
weight_enable=self.
|
| 399 |
)
|
| 400 |
|
| 401 |
# Connect accumulator to systolic array
|
|
@@ -421,8 +691,8 @@ class CompiledAccelerator:
|
|
| 421 |
self,
|
| 422 |
data_enable: WireVector | None = None,
|
| 423 |
data_inputs: list[WireVector] | None = None,
|
| 424 |
-
|
| 425 |
-
|
| 426 |
accum_addr: WireVector | None = None,
|
| 427 |
accum_mode: WireVector | None = None,
|
| 428 |
act_start: WireVector | None = None,
|
|
@@ -437,8 +707,9 @@ class CompiledAccelerator:
|
|
| 437 |
Args:
|
| 438 |
data_enable: 1-bit signal that enables data flow into the systolic array
|
| 439 |
data_inputs: List of input data wires for the systolic array. Must match array_size
|
| 440 |
-
|
| 441 |
-
|
|
|
|
| 442 |
accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
|
| 443 |
accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
|
| 444 |
act_start: 1-bit signal to enable passing data through the activation unit
|
|
@@ -464,21 +735,16 @@ class CompiledAccelerator:
|
|
| 464 |
)
|
| 465 |
self.data_ins[i] <<= wire
|
| 466 |
|
| 467 |
-
if
|
| 468 |
-
assert len(
|
| 469 |
-
self.
|
| 470 |
|
| 471 |
-
if
|
| 472 |
-
assert len(
|
| 473 |
-
f"
|
| 474 |
-
f"Expected {self.
|
| 475 |
)
|
| 476 |
-
|
| 477 |
-
assert len(wire) == self.config.weight_type.bitwidth(), (
|
| 478 |
-
f"Weight input wire width mismatch. "
|
| 479 |
-
f"Expected {self.config.weight_type.bitwidth()}, got {len(wire)}"
|
| 480 |
-
)
|
| 481 |
-
self.weights_in[i] <<= wire
|
| 482 |
|
| 483 |
if accum_addr is not None:
|
| 484 |
assert len(accum_addr) == self.config.accum_addr_width, (
|
|
@@ -515,6 +781,34 @@ class CompiledAccelerator:
|
|
| 515 |
assert len(valid) == 1, "Output valid signal must be a single bit wire"
|
| 516 |
valid <<= self.activation.outputs_valid
|
| 517 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
+from __future__ import annotations
 from dataclasses import dataclass
+from typing import Callable, Literal, Type, Dict
 import numpy as np
+import hashlib
+import json

 from pyrtl import (
     WireVector,
     Register,
+    Input,
     Output,
     Simulation,
+    CompiledSimulation,
     concat,
 )

+from .adders import float_adder
+from ..dtypes.bfloat16 import BF16
+
+from .multipliers import *
+from .adders import *
+from .lmul import *
 from .buffer import BufferMemory, WeightFIFO
 from .systolic import SystolicArrayDiP
 from .accumulators import Accumulator, TiledAccumulatorMemoryBank
+from .activations import ReluState, ReluUnit
 from ..dtypes import BaseFloat

+from dataclasses import dataclass


+@dataclass  # (frozen=True)
+class CompiledAcceleratorConfig:
+    """Configuration for a compiled accelerator."""

     array_size: int
+    activation_type: Type[BaseFloat]
     weight_type: Type[BaseFloat]
+    multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
+    accum_addr_width: int = 12  # 4096 accumulator slots
+    pipeline_pe: bool = False
+
+    def __post_init__(self):
+        """Validate configuration after initialization."""
+        # Ensure activation dtype has bitwidth >= weight dtype
+        if self.activation_type.bitwidth() < self.weight_type.bitwidth():
+            raise ValueError(
+                f"Activation dtype bitwidth ({self.activation_type.bitwidth()}) must be greater than or equal to "
+                f"weight dtype bitwidth ({self.weight_type.bitwidth()})"
+            )

+    @property
+    def name(self):
+        dtype_name = lambda d: d.bitwidth() if d != BF16 else "b16"
+        lmul = "-lmul" if "lmul" in self.multiplier.__name__.lower() else ""
+        mem = f"-m{self.accum_addr_width}" if self.accum_addr_width != 12 else ""
+        return (
+            f"w{dtype_name(self.weight_type)}"
+            f"a{dtype_name(self.activation_type)}"
+            f"-{self.array_size}x{self.array_size}"
+            f"{lmul}"
+            f"{'-p' if self.pipeline_pe else ''}"
+            f"{mem}"
+        )

+    def __repr__(self) -> str:
+        return (
+            "CompiledAcceleratorConfig(\n"
+            f"\tarray_size={self.array_size}\n"
+            f"\tactivation_type={self.activation_type.__name__}\n"
+            f"\tweight_type={self.weight_type.__name__}\n"
+            f"\tmultiplier={self.multiplier.__name__}\n"
+            f"\taccum_addr_width={self.accum_addr_width}\n"
+            f"\tpipeline={self.pipeline_pe}\n"
+            # f'\tname="{self.name}"\n'
+            ")"
+        )

+    def __hash__(self) -> int:
+        """Generate a consistent hash value for this configuration.
+
+        Returns:
+            An integer hash value.
+        """
+        # Create a dictionary of the key configuration parameters
+        config_dict = {
+            "array_size": self.array_size,
+            "activation_type": f"{self.activation_type.__module__}.{self.activation_type.__name__}",
+            "weight_type": f"{self.weight_type.__module__}.{self.weight_type.__name__}",
+            "multiplier": self.multiplier.__name__,
+            "accum_addr_width": self.accum_addr_width,
+            "pipeline": self.pipeline_pe,
+        }

+        # Generate a hash from the sorted JSON representation
+        hash_str = hashlib.sha256(
+            json.dumps(config_dict, sort_keys=True).encode()
+        ).hexdigest()

+        # Convert the first 16 characters of the hex string to an integer
+        return int(hash_str[:16], 16)

     @property
+    def id(self):
+        """Get a unique hexadecimal identifier for this configuration."""
+        return hex(self.__hash__())[2:]

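A minimal usage sketch of the config class above (illustrative only, not part of the commit; it assumes BF16 and float_multiplier are importable as in the module's own imports):

    # Hypothetical example: build a config and derive its short name and cache id.
    config = CompiledAcceleratorConfig(
        array_size=8,
        activation_type=BF16,
        weight_type=BF16,
        multiplier=float_multiplier,
    )
    print(config.name)  # "wb16ab16-8x8" for this non-lmul, non-pipelined setup
    print(config.id)    # up to 16 hex chars taken from the sha256 of the config dict
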
+class CompiledAccelerator:
+    def __init__(self, config: CompiledAcceleratorConfig):
         self.config = config

         # Instantiate hardware components
         self.systolic_array = SystolicArrayDiP(
             size=config.array_size,
+            data_type=config.activation_type,
             weight_type=config.weight_type,
+            accum_type=config.activation_type,
+            multiplier=config.multiplier,
+            adder=float_adder,
+            pipeline=config.pipeline_pe,
         )
         self.accumulator = Accumulator(
+            addr_width=12,
             array_size=config.array_size,
+            data_type=config.activation_type,
+            adder=float_adder,
         )
         self.activation = ReluUnit(
             size=config.array_size,
+            dtype=config.activation_type,
         )
         self.outputs = [
+            WireVector(config.activation_type.bitwidth())
+            for _ in range(config.array_size)
         ]

         # Connect components

         """Create unnamed WireVectors for control signals"""
         self.data_enable = WireVector(1)
         self.data_ins = [
+            WireVector(self.config.activation_type.bitwidth())
             for _ in range(self.config.array_size)
         ]

+        self.weight_enable = WireVector(1)
+        self.weights_in = [
+            WireVector(self.config.weight_type.bitwidth())
+            for _ in range(self.config.array_size)
+        ]

         self.accum_addr_in = WireVector(self.config.accum_addr_width)
         self.accum_mode_in = WireVector(1)

         self.act_func_in = WireVector(1)  # Apply activation function or passthrough

     def _create_pipeline_registers(self):
+        num_registers = self.config.array_size + 1 + int(self.config.pipeline_pe)

         self.accum_addr_regs = [
             Register(self.config.accum_addr_width) for _ in range(num_registers)

         self._create_pipeline_registers()

         # Connect buffer to external inputs
         self.systolic_array.connect_inputs(
             data_inputs=self.data_ins,
             enable_input=self.data_enable,
+            weight_inputs=self.weights_in,
+            weight_enable=self.weight_enable,
         )

         # Connect accumulator to systolic array

         self,
         data_enable: WireVector | None = None,
         data_inputs: list[WireVector] | None = None,
+        weight_enable: WireVector | None = None,
+        weights_in: list[WireVector] | None = None,
         accum_addr: WireVector | None = None,
         accum_mode: WireVector | None = None,
         act_start: WireVector | None = None,

         Args:
             data_enable: 1-bit signal that enables data flow into the systolic array
             data_inputs: List of input data wires for the systolic array. Must match array_size
+            weight_enable: 1-bit signal that enables writing new weights into the systolic array registers
+            weights_in: List of input weight wires for the systolic array. Must match array_size
             accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
             accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
             act_start: 1-bit signal to enable passing data through the activation unit

             f"Expected {self.config.array_size}, got {len(data_inputs)}"
         )
         for i, wire in enumerate(data_inputs):
+            assert len(wire) == self.config.activation_type.bitwidth(), (
                 f"Data input width mismatch. "
+                f"Expected {self.config.activation_type.bitwidth()}, got {len(wire)}"
             )
             self.data_ins[i] <<= wire

+        if weight_enable is not None:
+            assert len(weight_enable) == 1, "Weight enable signal must be 1 bit wide"
+            self.weight_enable <<= weight_enable

+        if weights_in is not None:
+            assert len(weights_in) == self.config.array_size, (
+                f"Weights input list length must match array size. "
+                f"Expected {self.config.array_size}, got {len(weights_in)}"
             )
+            for i, wire in enumerate(weights_in):
+                assert len(wire) == self.config.weight_type.bitwidth(), (
+                    f"Weight input wire width mismatch. "
+                    f"Expected {self.config.weight_type.bitwidth()}, got {len(wire)}"
+                )
+                self.weights_in[i] <<= wire

         if accum_addr is not None:
             assert len(accum_addr) == self.config.accum_addr_width, (

         assert len(valid) == 1, "Output valid signal must be a single bit wire"
         valid <<= self.activation.outputs_valid

+    def inspect_accumulator_state(self, sim: CompiledSimulation) -> np.ndarray:
+        """Return accumulator memory as an array.

         Args:
             sim: PyRTL simulation instance

         tiles = []
         for addr in range(2**self.config.accum_addr_width):
             row = [
+                float(
+                    self.config.activation_type(
+                        binint=sim.inspect_mem(bank).get(addr, 0)
+                    )
+                )
                 for bank in self.accumulator.memory_banks
             ]
             tiles.append(row)
         return np.array(tiles)

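A hedged simulation sketch for the class above (illustrative harness, not part of the commit; the activation-control wires would be driven analogously through the remaining connect_inputs keywords, and inspect_accumulator_state also works with a plain Simulation since it only calls inspect_mem):

    # Hypothetical: give the unnamed control wires named Inputs so the design can be simulated.
    import pyrtl
    pyrtl.reset_working_block()
    cfg = CompiledAcceleratorConfig(
        array_size=4, activation_type=BF16, weight_type=BF16, multiplier=float_multiplier
    )
    acc = CompiledAccelerator(cfg)
    acc.connect_inputs(
        data_enable=pyrtl.Input(1, "data_en"),
        data_inputs=[pyrtl.Input(16, f"d{i}") for i in range(4)],
        weight_enable=pyrtl.Input(1, "w_en"),
        weights_in=[pyrtl.Input(16, f"w{i}") for i in range(4)],
        accum_addr=pyrtl.Input(12, "addr"),
        accum_mode=pyrtl.Input(1, "mode"),
    )
    sim = pyrtl.Simulation()
    sim.step({"data_en": 0, "w_en": 0, "addr": 0, "mode": 0,
              **{f"d{i}": 0 for i in range(4)}, **{f"w{i}": 0 for i in range(4)}})
    print(acc.inspect_accumulator_state(sim).shape)  # (4096, 4) for the 12-bit default
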
+@dataclass
+class AcceleratorAnalysisConfig:
+    """Configuration for an accelerator to be generated for analysis."""
+
+    array_size: int
+    """
+    The size of the systolic array (N x N).
+    Determines the number of processing elements in the accelerator.
+    """
+
+    weight_type: Type[BaseFloat]
+    """
+    The floating-point data type for weights.
+    Must be a subclass of BaseFloat (e.g., Float8, BF16, Float32).
+    """
+
+    activation_type: Type[BaseFloat]
+    """
+    The floating-point data type for activations/inputs.
+    Must be a subclass of BaseFloat (e.g., Float8, BF16, Float32).
+    """
+
+    lmul: bool
+    """
+    Whether to use L-mul for multiplication operations.
+    If True, uses linear-time multipliers; if False, uses standard IEEE multipliers.
+    """
+
+    pipeline_level: Literal["low", "high"] | None
+    """
+    The level of pipelining in the accelerator:
+    - None: No pipelining (fully combinational design)
+    - 'low': Basic pipelining between multiplier and adder in each PE
+    - 'high': Full pipelining with pipelined arithmetic units
+    """
+
+    use_fast_internals: bool
+    """
+    Whether to use faster basic arithmetic implementations with more complex low-level RTL.
+    - True: uses optimized arithmetic units from PyRTL's rtllib
+    - False: prioritize simplicity over speed
+
+    WARNING: Setting this to True can make final synthesis of the Verilog output worse, since the synthesis tools may not be able to infer optimal circuits from the complex low-level RTL.
+    """
+
+    accum_addr_width: int = 12
+    """
+    The bit width of the accumulator address.
+    Determines the size of the accumulator memory (2^width entries).
+    Default is 12 bits (4096 entries).
+    """
+
+    def __post_init__(self):
+        # Ensure activation dtype has bitwidth >= weight dtype
+        if self.activation_type.bitwidth() < self.weight_type.bitwidth():
+            raise ValueError(
+                f"Activation dtype bitwidth ({self.activation_type.bitwidth()}) must be greater than or equal to "
+                f"weight dtype bitwidth ({self.weight_type.bitwidth()})"
+            )
+
+        # Determine if we should use pipelined arithmetic functions
+        use_pipelined_funcs = self.pipeline_level == "high"
+
+        # Set pipeline_pe flag for PE configuration
+        # True if any pipeline level is specified (low or high)
+        self.pipeline_pe = self.pipeline_level is not None
+
+        # Multiplier function selection using dictionary mapping
+        multiplier_map = {
+            # (lmul, use_pipelined_funcs, fast_internals) -> function
+            (True, True, True): lmul_pipelined_fast,
+            (True, True, False): lmul_pipelined,
+            (True, False, True): lmul_fast,
+            (True, False, False): lmul_simple,
+            (False, True, True): float_multiplier_pipelined_fast_unstable,
+            (False, True, False): float_multiplier_pipelined,
+            (False, False, True): float_multiplier_fast_unstable,
+            (False, False, False): float_multiplier,
+        }
+
+        # Adder function selection using dictionary mapping
+        adder_map = {
+            # (use_pipelined_funcs, fast_internals) -> function
+            (True, True): float_adder_pipelined_fast_unstable,
+            (True, False): float_adder_pipelined,
+            (False, True): float_adder_fast_unstable,
+            (False, False): float_adder,
+        }
+
+        # Select functions using the maps
+        self.multiplier_func = multiplier_map[
+            (self.lmul, use_pipelined_funcs, self.use_fast_internals)
+        ]
+        self.adder_func = adder_map[(use_pipelined_funcs, self.use_fast_internals)]
+
+    @property
+    def name(self):
+        dtype_name = lambda d: d.bitwidth() if d != BF16 else "b16"
+        mul = "-lmul" if self.lmul else "-ieee"
+        pipe_name_map = {"low": "-pipePE", "high": "-pipeALL"}
+        fast = "-fast" if self.use_fast_internals else ""
+        mem = f"-m{self.accum_addr_width}" if self.accum_addr_width != 12 else ""
+        return (
+            f"w{dtype_name(self.weight_type)}"
+            f"a{dtype_name(self.activation_type)}"
+            f"-{self.array_size}x{self.array_size}"
+            + mem
+            + mul
+            + fast
+            + pipe_name_map.get(self.pipeline_level, "")  # type: ignore
+        )
+
+
+class AcceleratorTopLevel(CompiledAccelerator):
+    def __init__(self, config: AcceleratorAnalysisConfig):
+        self.config = config
+
+        # Instantiate hardware components
+        self.systolic_array = SystolicArrayDiP(
+            size=config.array_size,
+            data_type=config.activation_type,
+            weight_type=config.weight_type,
+            accum_type=config.activation_type,
+            multiplier=config.multiplier_func,
+            adder=config.adder_func,
+            pipeline=config.pipeline_pe,
+        )
+        self.accumulator = Accumulator(
+            addr_width=12,
+            array_size=config.array_size,
+            data_type=config.activation_type,
+            adder=config.adder_func,
+        )
+        self.activation = ReluUnit(
+            size=config.array_size,
+            dtype=config.activation_type,
+        )
+        self.outputs = [
+            Output(config.activation_type.bitwidth(), f"out_{i}")
+            for i in range(config.array_size)
+        ]
+
+        # Connect everything together and create io ports
+        self._connect_components()
+        self.valid_out = Output(1, "valid_out")
+        self.valid_out <<= self.activation.outputs_valid
+
+    def _create_control_wires(self):
+        """Create named Input wires for control signals"""
+        self.data_enable = Input(1, "data_enable")
+        self.data_ins = [
+            Input(self.config.activation_type.bitwidth(), f"data_in_{i}")
+            for i in range(self.config.array_size)
+        ]
+        self.weight_enable = Input(1, "weight_enable")
+        self.weights_in = [
+            Input(self.config.weight_type.bitwidth(), f"weight_in_{i}")
+            for i in range(self.config.array_size)
+        ]
+        self.accum_addr_in = Input(self.config.accum_addr_width, "accum_addr_in")
+        self.accum_mode_in = Input(1, "accum_mode_in")
+        self.act_start_in = Input(1, "act_start_in")
+        self.act_func_in = Input(1, "act_func_in")


+@dataclass(unsafe_hash=True)
+class AcceleratorConfig:
+    """Configuration class for a systolic array accelerator.
+
+    This class defines the parameters and specifications for a systolic array
+    accelerator including array dimensions, data types, arithmetic operations,
+    and memory configuration.
+    """
+
+    array_size: int
+    """Dimension of the systolic array (always square)"""
+
+    num_weight_tiles: int
+    """Number of weight tiles in the FIFO. Each tile is equal to the size of the systolic array"""
+
+    data_type: Type[BaseFloat]
+    """Floating point format of input data to the systolic array"""
+
+    weight_type: Type[BaseFloat]
+    """Floating point format of weight inputs"""
+
+    accum_type: Type[BaseFloat]
+    """Floating point format to accumulate values in"""
+
+    pe_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
+    """Function to generate adder hardware for the processing elements"""
+
+    accum_adder: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
+    """Function to generate adder hardware for the accumulator buffer"""
+
+    pe_multiplier: Callable[[WireVector, WireVector, Type[BaseFloat]], WireVector]
+    """Function to generate multiplier hardware for the processing elements"""
+
+    pipeline: bool
+    """Whether to add a pipeline stage in processing elements between multiplier and adder"""
+
+    accum_addr_width: int
+    """Address width for accumulator memory. Determines the number of individually addressable locations"""
+
+    @property
+    def weight_tile_addr_width(self):
+        """Get the width of the weight tile address bus in bits"""
+        return (self.num_weight_tiles - 1).bit_length()
+
+
+class Accelerator:
     def __init__(self, config: AcceleratorConfig):
         self.config = config

         # Instantiate hardware components
+        self.fifo = WeightFIFO(
+            array_size=config.array_size,
+            num_tiles=config.num_weight_tiles,
+            dtype=config.weight_type,
+        )
         self.systolic_array = SystolicArrayDiP(
             size=config.array_size,
             data_type=config.data_type,

             for _ in range(self.config.array_size)
         ]

+        self.weight_start_in = WireVector(1)
+        self.weight_tile_addr_in = WireVector(self.fifo.tile_addr_width)

         self.accum_addr_in = WireVector(self.config.accum_addr_width)
         self.accum_mode_in = WireVector(1)

         self.accum_mode_out = WireVector(1)
         self.accum_mode_out <<= self.accum_mode_regs[-1]

+        self.act_start_regs = [Register(1) for _ in range(num_registers)]
+        self.act_enable_regs = [Register(1) for _ in range(num_registers)]
+        self.act_start_regs[0].next <<= self.act_start_in
+        self.act_enable_regs[0].next <<= self.act_func_in
+
+        # self.act_control_regs = [Register(2) for _ in range(num_registers)]
+        # self.act_control_regs[0].next <<= concat(self.act_start_in, self.act_func_in)

         self.accum_addr_regs[0].next <<= self.accum_addr_in
         self.accum_mode_regs[0].next <<= self.accum_mode_in
         for i in range(1, len(self.accum_addr_regs)):
             self.accum_addr_regs[i].next <<= self.accum_addr_regs[i - 1]
             self.accum_mode_regs[i].next <<= self.accum_mode_regs[i - 1]
+            # self.act_control_regs[i].next <<= self.act_control_regs[i - 1]
+            if i < len(self.act_start_regs):
+                self.act_enable_regs[i].next <<= self.act_enable_regs[i - 1]
+                self.act_start_regs[i].next <<= self.act_start_regs[i - 1]

         self.act_addr = Register(self.config.accum_addr_width)
         self.act_func = Register(1)
         self.act_start = Register(1)

         self.act_addr.next <<= self.accum_addr_out
+        # self.act_func.next <<= self.act_control_regs[-1][0]
+        # self.act_start.next <<= self.act_control_regs[-1][1]
+        self.act_func.next <<= self.act_enable_regs[-1]
+        self.act_start.next <<= self.act_start_regs[-1]

     def _connect_components(self):
         """Internal component connections"""

         self._create_pipeline_registers()

         # Connect buffer to external inputs
+        self.fifo.connect_inputs(
+            start=self.weight_start_in,
+            tile_addr=self.weight_tile_addr_in,
+        )
+
         self.systolic_array.connect_inputs(
             data_inputs=self.data_ins,
             enable_input=self.data_enable,
+            weight_inputs=self.fifo.outputs.weights,
+            weight_enable=self.fifo.outputs.active,
         )

         # Connect accumulator to systolic array

         self,
         data_enable: WireVector | None = None,
         data_inputs: list[WireVector] | None = None,
+        weight_start: WireVector | None = None,
+        weight_tile_addr: WireVector | None = None,
         accum_addr: WireVector | None = None,
         accum_mode: WireVector | None = None,
         act_start: WireVector | None = None,

         Args:
             data_enable: 1-bit signal that enables data flow into the systolic array
             data_inputs: List of input data wires for the systolic array. Must match array_size
+            weight_start: 1-bit signal that triggers loading of a new weight tile when pulsed high
+            weight_tile_addr: Address selecting which weight tile to load from the FIFO.
+                Width must match the FIFO's tile address width
             accum_addr: Address for the accumulator memory bank. Width must match accum_addr_width
             accum_mode: 1-bit mode select (0=overwrite, 1=accumulate with existing values)
             act_start: 1-bit signal to enable passing data through the activation unit

         )
         self.data_ins[i] <<= wire

+        if weight_start is not None:
+            assert len(weight_start) == 1, "Weight start signal must be 1 bit wide"
+            self.weight_start_in <<= weight_start

+        if weight_tile_addr is not None:
+            assert len(weight_tile_addr) == self.fifo.tile_addr_width, (
+                f"Weight tile address width mismatch. "
+                f"Expected {self.fifo.tile_addr_width}, got {len(weight_tile_addr)}"
             )
+            self.weight_tile_addr_in <<= weight_tile_addr

         if accum_addr is not None:
             assert len(accum_addr) == self.config.accum_addr_width, (

         assert len(valid) == 1, "Output valid signal must be a single bit wire"
         valid <<= self.activation.outputs_valid

+    def inspect_systolic_array_state(self, sim: Simulation):
+        """Return current PE array state"""
+        return self.systolic_array.get_state(sim)
+
+    def inspect_accumulator_state(self, sim: Simulation) -> np.ndarray:
+        """Return all accumulator tiles as a 2D array.
+
+        Args:
+            sim: PyRTL simulation instance
+
+        Returns:
+            2D numpy array of shape (2**accum_addr_width, array_size) containing
+            all accumulator tile data converted to floating point values.
+            Each tile contains array_size rows with array_size columns.
+        """
+        tiles = []
+        for addr in range(2**self.config.accum_addr_width):
+            row = [
+                float(self.config.accum_type(binint=sim.inspect_mem(bank).get(addr, 0)))
+                for bank in self.accumulator.memory_banks
+            ]
+            tiles.append(row)
+        return np.array(tiles)
+
+    def inspect_activation_state(self, sim: Simulation) -> ReluState:
+        """Return current activation unit state"""
+        return self.activation.inspect_state(sim)


 @dataclass
 class TiledAcceleratorConfig:
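A hedged end-to-end sketch tying together the AcceleratorAnalysisConfig and AcceleratorTopLevel classes above (illustrative only, not part of the commit; it assumes Float8 and BF16 are importable from this package, and uses PyRTL's standard output_to_verilog):

    # Hypothetical export flow: pick a design point, elaborate the top level
    # with named I/O, and dump Verilog for synthesis.
    import pyrtl
    pyrtl.reset_working_block()
    cfg = AcceleratorAnalysisConfig(
        array_size=4,
        weight_type=Float8,
        activation_type=BF16,
        lmul=True,
        pipeline_level="low",
        use_fast_internals=False,
    )
    print(cfg.name)                      # "w8ab16-4x4-lmul-pipePE"
    print(cfg.multiplier_func.__name__)  # lmul_simple, per the (lmul, pipelined, fast) map
    top = AcceleratorTopLevel(cfg)
    with open(f"{cfg.name}.v", "w") as f:
        pyrtl.output_to_verilog(f)
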
hardware_accelerators/rtllib/activations.py
CHANGED

@@ -1,4 +1,6 @@
-from
 from pyrtl import (
     WireVector,
     Input,

@@ -13,6 +15,32 @@ from pyrtl import (
 from ..dtypes.base import BaseFloat


 class ReluUnit:
     def __init__(self, size: int, dtype: Type[BaseFloat]):
         self.size = size

@@ -21,18 +49,24 @@ class ReluUnit:
         # Control signals
         self.start = WireVector(1)  # trigger to latch new enable value
         self.enable_in = WireVector(1)  # input enable value to latch
-        self.inputs_valid = WireVector(1)  # indicates if inputs are valid
         self.enable_reg = Register(1)  # stateful enable register

         # Input and output data
-        self.
-        self.
         self.outputs_valid = WireVector(1)
-        self.outputs_valid <<= self.

     def relu(self, x: WireVector):
         # Use enable_reg instead of enable wire
-        pass_condition = self.
             ~self.enable_reg | (self.enable_reg & ~x[-1])
         )
         return select(pass_condition, x, Const(0, self.dtype.bitwidth()))

@@ -48,7 +82,7 @@ class ReluUnit:
             len(inputs) == self.size
         ), f"Activation module input size mismatch. Expected {self.size}, got {len(inputs)}"
         for i in range(self.size):
-            self.
         self.inputs_valid <<= valid
         self.enable_in <<= enable
         self.start <<= start

@@ -85,6 +119,34 @@ class ReluUnit:
         """
         return [float(self.dtype(binint=sim.inspect(out.name))) for out in self.outputs]


 # class ReluUnit:
 #     def __init__(self, size: int, dtype: Type[BaseFloat]):
+from dataclasses import dataclass
+import numpy as np
+from typing import TYPE_CHECKING, Sequence, Type
 from pyrtl import (
     WireVector,
     Input,

 from ..dtypes.base import BaseFloat


+@dataclass
+class ReluState:
+    start: int
+    enable_in: int
+    enable_reg: int
+    inputs_valid: int
+    inputs: np.ndarray
+    registers: np.ndarray
+    outputs_valid: int
+    outputs: np.ndarray
+
+    def __repr__(self) -> str:
+        """Pretty print the ReLU state"""
+        status = "enabled" if self.enable_reg else "disabled"
+        valid_str = "(valid)" if self.outputs_valid else "(invalid)"
+
+        return (
+            f"ReLU {status} {valid_str}\n"
+            f"  Control: start={self.start}, enable_in={self.enable_in}, "
+            f"enable_reg={self.enable_reg}, inputs_valid={self.inputs_valid}\n"
+            f"  Inputs: {np.array2string(self.inputs, precision=4, suppress_small=True)}\n"
+            f"  Registers: {np.array2string(self.registers, precision=4, suppress_small=True)}\n"
+            f"  Outputs: {np.array2string(self.outputs, precision=4, suppress_small=True)}"
+        )
+
+
 class ReluUnit:
     def __init__(self, size: int, dtype: Type[BaseFloat]):
         self.size = size

         # Control signals
         self.start = WireVector(1)  # trigger to latch new enable value
         self.enable_in = WireVector(1)  # input enable value to latch
         self.enable_reg = Register(1)  # stateful enable register
+        self.inputs_valid = WireVector(1)  # indicates if inputs are valid
+        self.valid_reg = Register(1)  # stateful valid register
+        self.valid_reg.next <<= self.inputs_valid

         # Input and output data
+        self.data_in = [WireVector(dtype.bitwidth()) for _ in range(size)]
+        self.data_regs = [Register(dtype.bitwidth()) for _ in range(size)]
+        for data, reg in zip(self.data_in, self.data_regs):
+            reg.next <<= data
+
+        self.outputs = [self.relu(x) for x in self.data_regs]
         self.outputs_valid = WireVector(1)
+        self.outputs_valid <<= self.valid_reg

     def relu(self, x: WireVector):
         # Use enable_reg instead of enable wire
+        pass_condition = self.valid_reg & (
             ~self.enable_reg | (self.enable_reg & ~x[-1])
         )
         return select(pass_condition, x, Const(0, self.dtype.bitwidth()))

             len(inputs) == self.size
         ), f"Activation module input size mismatch. Expected {self.size}, got {len(inputs)}"
         for i in range(self.size):
+            self.data_in[i] <<= inputs[i]
         self.inputs_valid <<= valid
         self.enable_in <<= enable
         self.start <<= start

         """
         return [float(self.dtype(binint=sim.inspect(out.name))) for out in self.outputs]

+    def inspect_state(self, sim: Simulation) -> ReluState:
+        """Inspect current state of the ReLU unit."""
+        return ReluState(
+            start=sim.inspect(self.start.name),
+            enable_in=sim.inspect(self.enable_in.name),
+            enable_reg=sim.inspect(self.enable_reg.name),
+            inputs_valid=sim.inspect(self.inputs_valid.name),
+            inputs=np.array(
+                [
+                    float(self.dtype(binint=sim.inspect(inp.name)))
+                    for inp in self.data_in
+                ]
+            ),
+            registers=np.array(
+                [
+                    float(self.dtype(binint=sim.inspect(reg.name)))
+                    for reg in self.data_regs
+                ]
+            ),
+            outputs_valid=sim.inspect(self.outputs_valid.name),
+            outputs=np.array(
+                [
+                    float(self.dtype(binint=sim.inspect(out.name)))
+                    for out in self.outputs
+                ]
+            ),
+        )
+

 # class ReluUnit:
 #     def __init__(self, size: int, dtype: Type[BaseFloat]):
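A small behavioral note on the relu() select above: pass_condition = valid_reg & (~enable_reg | (enable_reg & ~x[-1])), so data passes through when the unit holds valid data and either ReLU is disabled (pure passthrough) or the sign bit x[-1] is clear; otherwise the output is forced to the dtype's zero encoding. A minimal software mirror of that condition (illustrative only, not part of the commit):

    # Hypothetical truth check of the RTL pass condition: True -> pass x, False -> output 0.
    def relu_model(valid: int, enable: int, sign_bit: int) -> bool:
        return bool(valid and ((not enable) or (enable and not sign_bit)))

    assert relu_model(valid=1, enable=0, sign_bit=1)       # disabled: negatives pass
    assert not relu_model(valid=1, enable=1, sign_bit=1)   # enabled: negatives zeroed
    assert relu_model(valid=1, enable=1, sign_bit=0)       # enabled: positives pass
    assert not relu_model(valid=0, enable=0, sign_bit=0)   # invalid data never passes
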
hardware_accelerators/rtllib/adders.py
CHANGED

@@ -17,6 +17,7 @@ def float_adder(
     float_a: WireVector,
     float_b: WireVector,
     dtype: Type[BaseFloat],
 ) -> WireVector:

     e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()

@@ -26,7 +27,7 @@ def float_adder(
     )

     sign_xor, exp_larger, signed_shift, mant_smaller, mant_larger = adder_stage_2(
-        sign_a, sign_b, exp_a, exp_b, mantissa_a, mantissa_b, e_bits, m_bits
     )

     abs_shift = WireVector(e_bits)  # , "abs_shift")

@@ -37,7 +38,7 @@ def float_adder(
     )

     mantissa_sum, is_neg, lzc = adder_stage_4(
-        aligned_mant_msb, mant_larger, sign_xor, m_bits
     )

     final_sign, final_exp, norm_mantissa = adder_stage_5(

@@ -56,13 +57,52 @@ def float_adder(
     )

     float_result = WireVector(dtype.bitwidth())  # , "float_result")
-
     return float_result


 ### ===================================================================
 ### Simple Pipeline Design
 ### ===================================================================


 class FloatAdderPipelined(SimplePipeline):

@@ -72,6 +112,7 @@ class FloatAdderPipelined(SimplePipeline):
     float_b: WireVector,
     w_en: WireVector,
     dtype: Type[BaseFloat],
 ):
     """
     Initialize a pipelined BFloat16 adder with write enable control.

@@ -134,17 +175,17 @@ class FloatAdderPipelined(SimplePipeline):
     write enable is not 1 bit
     """
     assert (
-        len(float_a) == len(float_b) ==
     ), f"float inputs must be {dtype.bitwidth()} bits"
     assert len(w_en) == 1, "write enable signal must be 1 bit"
-
     self.e_bits, self.m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
     # Define inputs and outputs
     self._float_a, self._float_b = float_a, float_b
     self._write_enable = w_en
-    # self._result = pyrtl.Register(self.e_bits + self.m_bits + 1,
     self._result_out = pyrtl.WireVector(dtype.bitwidth())  # , "_result")
-    super(

     @property
     def result(self):

@@ -183,6 +224,7 @@ class FloatAdderPipelined(SimplePipeline):
     self.mant_b,
     self.e_bits,
     self.m_bits,
     )

     def stage2(self):

@@ -219,7 +261,11 @@ class FloatAdderPipelined(SimplePipeline):

     # Perform mantissa addition and leading zero detection
     self.mant_sum, self.is_neg, self.lzc = adder_stage_4(
-        self.aligned_mant_msb,
     )

     def stage4(self):
 def float_adder(
     float_a: WireVector,
     float_b: WireVector,
     dtype: Type[BaseFloat],
+    fast: bool = False,
 ) -> WireVector:

     e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()

     )

     sign_xor, exp_larger, signed_shift, mant_smaller, mant_larger = adder_stage_2(
+        sign_a, sign_b, exp_a, exp_b, mantissa_a, mantissa_b, e_bits, m_bits, fast
     )

     abs_shift = WireVector(e_bits)  # , "abs_shift")

     )

     mantissa_sum, is_neg, lzc = adder_stage_4(
+        aligned_mant_msb, mant_larger, sign_xor, m_bits, fast
     )

     final_sign, final_exp, norm_mantissa = adder_stage_5(

     )

     float_result = WireVector(dtype.bitwidth())  # , "float_result")
+
+    # Zero detection logic
+    a_is_zero = ~pyrtl.or_all_bits(float_a[:-1])
+    b_is_zero = ~pyrtl.or_all_bits(float_b[:-1])
+
+    with pyrtl.conditional_assignment:
+        with a_is_zero:
+            float_result |= float_b
+        with b_is_zero:
+            float_result |= float_a
+        with pyrtl.otherwise:
+            float_result |= pyrtl.concat(final_sign, final_exp, norm_mantissa)
+
     return float_result


+def float_adder_fast_unstable(
+    float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]
+) -> WireVector:
+    return float_adder(float_a, float_b, dtype, fast=True)
+
+
 ### ===================================================================
 ### Simple Pipeline Design
 ### ===================================================================
+# TODO: add zero detection logic
+
+
+def float_adder_pipelined(
+    float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat], fast: bool = False
+) -> WireVector:
+    w_en = pyrtl.Input(1)
+    w_en.name = w_en.name.replace("tmp", "adder_w_en_in")
+    adder = FloatAdderPipelined(float_a, float_b, w_en, dtype, fast=fast)
+    return adder._result_out
+
+
+def float_adder_pipelined_fast_unstable(
+    float_a: WireVector,
+    float_b: WireVector,
+    dtype: Type[BaseFloat],
+) -> WireVector:
+    w_en = pyrtl.Input(1)
+    w_en.name = w_en.name.replace("tmp", "adder_w_en_in")
+    adder = FloatAdderPipelined(float_a, float_b, w_en, dtype, fast=True)
+    return adder._result_out


 class FloatAdderPipelined(SimplePipeline):

     float_b: WireVector,
     w_en: WireVector,
     dtype: Type[BaseFloat],
+    fast: bool = False,
 ):
     """
     Initialize a pipelined floating-point adder with write enable control.

     write enable is not 1 bit
     """
     assert (
+        len(float_a) == len(float_b) == dtype.bitwidth()
     ), f"float inputs must be {dtype.bitwidth()} bits"
     assert len(w_en) == 1, "write enable signal must be 1 bit"
+    self._fast = fast
     self.e_bits, self.m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
     # Define inputs and outputs
     self._float_a, self._float_b = float_a, float_b
     self._write_enable = w_en
+    # self._result = pyrtl.Register(self.e_bits + self.m_bits + 1, "result")
     self._result_out = pyrtl.WireVector(dtype.bitwidth())  # , "_result")
+    super().__init__("float_adder")

     @property
     def result(self):

     self.mant_b,
     self.e_bits,
     self.m_bits,
+    self._fast,
     )

     def stage2(self):

     # Perform mantissa addition and leading zero detection
     self.mant_sum, self.is_neg, self.lzc = adder_stage_4(
+        self.aligned_mant_msb,
+        self.mant_larger,
+        self.sign_xor,
+        self.m_bits,
+        self._fast,
     )

     def stage4(self):
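A hedged usage sketch for the combinational float_adder above (illustrative harness, not part of the commit; the BF16 bit patterns were computed by hand):

    # Hypothetical test bench: add 1.5 + 1.0 in BF16 and read back the bits.
    import pyrtl
    pyrtl.reset_working_block()
    a = pyrtl.Input(16, "a")
    b = pyrtl.Input(16, "b")
    out = pyrtl.Output(16, "out")
    out <<= float_adder(a, b, BF16)

    sim = pyrtl.Simulation()
    sim.step({"a": 0x3FC0, "b": 0x3F80})     # 1.5 and 1.0
    print(hex(sim.inspect("out")))           # expected 0x4020, i.e. 2.5 in BF16

Note the zero-detection branch added above also makes x + 0 return x exactly, bypassing the normalization pipeline for zero operands.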
hardware_accelerators/rtllib/legacy.py
CHANGED

@@ -1,11 +1,82 @@
 import pyrtl
 from pyrtl.rtllib.adders import carrysave_adder, kogge_stone

 from .utils.lmul_utils import get_combined_offset

 ###########################
 # Old code below
 ###########################


 # BF16 Naive Combinatorial
+from typing import Type
 import pyrtl
+from pyrtl import WireVector
 from pyrtl.rtllib.adders import carrysave_adder, kogge_stone

+from ..dtypes.base import BaseFloat
+
 from .utils.lmul_utils import get_combined_offset

+
 ###########################
 # Old code below
 ###########################
+def lmul_simple(
+    float_a: WireVector,
+    float_b: WireVector,
+    dtype: Type[BaseFloat],
+):
+    """Linear time complexity float multiply unit in the simplest configuration."""
+    e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
+    em_bits = e_bits + m_bits
+    sign_out = float_a[em_bits] ^ float_b[em_bits]
+
+    unsigned_offset = pyrtl.Const(get_combined_offset(e_bits, m_bits), em_bits)
+    result_sum = float_a[:em_bits] + float_b[:em_bits] - unsigned_offset
+
+    fp_out = WireVector(bitwidth=em_bits + 1)
+    fp_out <<= pyrtl.concat(sign_out, pyrtl.truncate(result_sum, em_bits))
+    return fp_out
+
+
+def lmul_fast(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
+    e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
+    em_bits = e_bits + m_bits
+    sign_a = float_a[em_bits]
+    sign_b = float_b[em_bits]
+    exp_mantissa_a = float_a[:em_bits]
+    exp_mantissa_b = float_b[:em_bits]
+    fp_out = WireVector(em_bits + 1)
+
+    # Calculate result sign
+    result_sign = sign_a ^ sign_b
+
+    # Add exp_mantissa parts using kogge_stone adder (faster than ripple)
+    # exp_mantissa_sum = kogge_stone(exp_mantissa_a, exp_mantissa_b)
+
+    # Get the combined offset-bias constant
+    OFFSET_MINUS_BIAS = pyrtl.Const(
+        get_combined_offset(e_bits, m_bits, True), bitwidth=em_bits
+    )
+
+    # Add offset-bias value - this will be 8 bits including carry
+    # final_sum = kogge_stone(exp_mantissa_sum, OFFSET_MINUS_BIAS)
+
+    final_sum = carrysave_adder(
+        exp_mantissa_a, exp_mantissa_b, OFFSET_MINUS_BIAS, final_adder=kogge_stone
+    )
+
+    # Select result based on carry and MSB:
+    # carry=1: overflow -> 0x7F
+    # carry=0, msb=0: underflow -> 0x00
+    # carry=0, msb=1: normal -> result_bits
+
+    MAX_VALUE = pyrtl.Const(2**em_bits - 1, bitwidth=em_bits)  # , name="max_value")
+
+    if e_bits == 4 and m_bits == 3:
+        MAX_VALUE = pyrtl.Const(0x7F, 7)
+
+    mantissa_result = pyrtl.mux(
+        final_sum[em_bits:],
+        pyrtl.Const(0, bitwidth=em_bits),
+        final_sum[:em_bits],
+        default=MAX_VALUE,
+    )
+
+    # Combine sign and result
+    fp_out <<= pyrtl.concat(result_sign, mantissa_result)
+
+    return fp_out


 # BF16 Naive Combinatorial
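For intuition, the lmul units above approximate a float multiply with integer arithmetic on the raw bit patterns: treating x = exponent|mantissa as an unsigned integer, x roughly encodes log2 of the value plus the bias, so x_a + x_b - offset approximates the bit pattern of the product. A hedged numeric sketch (software model only, not part of the commit; it assumes the simplest offset, bias << m_bits, whereas the repo's get_combined_offset also folds in a mantissa correction term):

    # Hypothetical software model of the simple L-mul for BF16 (e=8, m=7 bits).
    import struct

    def bf16_bits(x: float) -> int:
        return struct.unpack(">I", struct.pack(">f", x))[0] >> 16

    def bits_to_float(b: int) -> float:
        return struct.unpack(">f", struct.pack(">I", b << 16))[0]

    E, M = 8, 7
    offset = 127 << M  # bias * 2^M; the RTL uses get_combined_offset(E, M)
    a, b = 1.5, 2.25
    approx = (bf16_bits(a) & 0x7FFF) + (bf16_bits(b) & 0x7FFF) - offset
    print(bits_to_float(approx), a * b)  # ~3.25 vs 3.375: close, not exact
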
hardware_accelerators/rtllib/lmul.py
CHANGED

@@ -1,92 +1,92 @@
 from typing import Type

 import pyrtl
-from pyrtl import WireVector
-from pyrtl.rtllib.adders import carrysave_adder, kogge_stone

 from ..dtypes import BaseFloat, Float8
-from .utils.lmul_utils import get_combined_offset


-def
-    float_a: WireVector,
-    float_b: WireVector,
-    dtype: Type[BaseFloat],
-):
-    """Linear time complexity float multiply unit in the simplest configuration."""
-    e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
-    em_bits = e_bits + m_bits
-    sign_out = float_a[em_bits] ^ float_b[em_bits]
-
-    unsigned_offset = pyrtl.Const(get_combined_offset(e_bits, m_bits), em_bits)
-    result_sum = float_a[:em_bits] + float_b[:em_bits] - unsigned_offset
-
-    fp_out = WireVector(bitwidth=em_bits + 1)
-    fp_out <<= pyrtl.concat(sign_out, pyrtl.truncate(result_sum, em_bits))
-    return fp_out
-
-
-def lmul_fast(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
     e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
     em_bits = e_bits + m_bits
     sign_a = float_a[em_bits]
     sign_b = float_b[em_bits]
     exp_mantissa_a = float_a[:em_bits]
     exp_mantissa_b = float_b[:em_bits]
-    fp_out = WireVector(em_bits + 1)

-
-    # final_sum = kogge_stone(exp_mantissa_sum, OFFSET_MINUS_BIAS)

-    final_sum = carrysave_adder(
-        exp_mantissa_a, exp_mantissa_b, OFFSET_MINUS_BIAS, final_adder=kogge_stone
-    )

-    # carry=0, msb=0: underflow -> 0x00
-    # carry=0, msb=1: normal -> result_bits

-    MAX_VALUE = pyrtl.Const(2**em_bits - 1, bitwidth=em_bits)  # , name="max_value")

-    mantissa_result = pyrtl.mux(
-        final_sum[em_bits:],
-        pyrtl.Const(0, bitwidth=em_bits),
-        final_sum[:em_bits],
-        default=MAX_VALUE,
-    )

-
-# Float8 fast pipelined lmul
 class LmulPipelined:
     def __init__(
         self,
         float_a: WireVector,
         float_b: WireVector,
         dtype: Type[BaseFloat],
     ):
         self.e_bits = dtype.exponent_bits()
         self.m_bits = dtype.mantissa_bits()
         self.em_bits = dtype.bitwidth() - 1

         # Inputs and Outputs
         assert (

@@ -137,13 +137,16 @@ class LmulPipelined:
     # Calculate and register sign
     self.reg_sign.next <<= sign_a ^ sign_b

-    #
     self.reg_final_sum.next <<= final_sum
 from typing import Type

 import pyrtl
+from pyrtl import WireVector, conditional_assignment
+from pyrtl.rtllib.adders import carrysave_adder, kogge_stone, fast_group_adder

 from ..dtypes import BaseFloat, Float8
+from .utils.lmul_utils import get_combined_offset, lmul_offset_rtl


+def lmul(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat], fast=False):
     e_bits, m_bits = dtype.exponent_bits(), dtype.mantissa_bits()
     em_bits = e_bits + m_bits
     sign_a = float_a[em_bits]
     sign_b = float_b[em_bits]
+    exp_a = float_a[m_bits:-1]
+    exp_b = float_b[m_bits:-1]
     exp_mantissa_a = float_a[:em_bits]
     exp_mantissa_b = float_b[:em_bits]

+    zero_or_subnormal = WireVector(1)
+    final_sum = WireVector(em_bits + 2)
+    carry_msb = WireVector(2)
+    fp_out = WireVector(dtype.bitwidth())

+    OFFSET_MINUS_BIAS = lmul_offset_rtl(dtype)
+    MAX_VALUE = pyrtl.Const(dtype.binary_max(), bitwidth=em_bits)

+    if fast:
+        final_sum <<= carrysave_adder(
+            exp_mantissa_a, exp_mantissa_b, OFFSET_MINUS_BIAS, final_adder=kogge_stone
+        )
+    else:
+        final_sum <<= exp_mantissa_a + exp_mantissa_b + OFFSET_MINUS_BIAS
+
+    carry_msb <<= final_sum[em_bits:]
+    zero_or_subnormal <<= ~pyrtl.or_all_bits(exp_a) | ~pyrtl.or_all_bits(exp_b)
+
+    with conditional_assignment:
+        with zero_or_subnormal:
+            fp_out |= 0
+        with carry_msb == 0:
+            fp_out |= 0
+        with carry_msb == 1:
+            fp_out |= pyrtl.concat(sign_a ^ sign_b, final_sum[:em_bits])
+        with pyrtl.otherwise:
+            fp_out |= pyrtl.concat(sign_a ^ sign_b, MAX_VALUE)

+    return fp_out


+def lmul_simple(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
+    return lmul(float_a, float_b, dtype, fast=False)


+def lmul_fast(float_a: WireVector, float_b: WireVector, dtype: Type[BaseFloat]):
+    return lmul(float_a, float_b, dtype, fast=True)


+def lmul_pipelined(
+    float_a: WireVector,
+    float_b: WireVector,
+    dtype: Type[BaseFloat],
+) -> WireVector:
+    mult = LmulPipelined(float_a, float_b, dtype)
+    return mult.output_reg
+
+
+def lmul_pipelined_fast(
+    float_a: WireVector,
+    float_b: WireVector,
+    dtype: Type[BaseFloat],
+) -> WireVector:
+    mult = LmulPipelined(float_a, float_b, dtype, fast=True)
+    return mult.output_reg


 class LmulPipelined:
     def __init__(
         self,
         float_a: WireVector,
         float_b: WireVector,
         dtype: Type[BaseFloat],
+        fast: bool = False,
     ):
         self.e_bits = dtype.exponent_bits()
         self.m_bits = dtype.mantissa_bits()
         self.em_bits = dtype.bitwidth() - 1
+        self._fast = fast

         # Inputs and Outputs
         assert (

     # Calculate and register sign
     self.reg_sign.next <<= sign_a ^ sign_b

+    # Add the floating point numbers with special lmul offset
+    if self._fast:
+        final_sum = carrysave_adder(
+            exp_mantissa_a,
+            exp_mantissa_b,
+            self.OFFSET_MINUS_BIAS,
+            final_adder=kogge_stone,
+        )
+    else:
+        final_sum = exp_mantissa_a + exp_mantissa_b + self.OFFSET_MINUS_BIAS

     self.reg_final_sum.next <<= final_sum
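The conditional assignment in lmul() above implements flush-to-zero and saturation via the two extra bits of final_sum: carry_msb == 0 means the biased sum underflowed (output 0), carry_msb == 1 is the normal in-range case, and carry_msb >= 2 means overflow (clamp to the dtype's maximum pattern); zero or subnormal operands short-circuit to 0 first. A hedged instantiation sketch (illustrative only, not part of the commit):

    # Hypothetical: build a combinational lmul multiplier for Float8 inputs.
    import pyrtl
    pyrtl.reset_working_block()
    x = pyrtl.Input(8, "x")
    y = pyrtl.Input(8, "y")
    p = pyrtl.Output(8, "p")
    p <<= lmul_simple(x, y, Float8)  # or lmul_fast for the carry-save/Kogge-Stone variant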