| """Deploy a trained SHD model to the Neurocore SDK or evaluate quantization.
|
|
|
| Loads a PyTorch checkpoint from shd_train.py, quantizes weights to int16,
|
| and evaluates accuracy with quantized weights. Also builds an SDK Network
|
| for deployment to the FPGA via CUBA neurons.
|
|
|
| Supports both LIF and adLIF checkpoints. For adLIF, adaptation parameters
|
| (rho, beta_a) are training-only; only alpha (membrane decay) deploys as decay_v.
|
|
|
| Usage:
|
| python shd_deploy.py --checkpoint shd_model.pt --data-dir data/shd
|
| python shd_deploy.py --checkpoint shd_adlif_model.pt --neuron-type adlif
|
| """
|
|
|
| import os
|
| import sys
|
| import argparse
|
| import numpy as np
|
|
|
| import torch
|
| from torch.utils.data import DataLoader
|
|
|
|
|
# Make the SDK package (one directory up) and this example's sibling modules
# (shd_loader, shd_train) importable when the script is run directly.
_SDK_DIR = os.path.normpath(os.path.join(os.path.dirname(__file__), ".."))
if _SDK_DIR not in sys.path:
    sys.path.insert(0, _SDK_DIR)
sys.path.insert(0, os.path.dirname(__file__))
|
|
|
| from shd_loader import SHDDataset, collate_fn, N_CHANNELS, N_CLASSES
|
| from shd_train import SHDSNN
|
|
|
| from neurocore import Network
|
| from neurocore.constants import WEIGHT_MIN, WEIGHT_MAX
|
|
|
|
|
def quantize_weights(w_float, threshold_float, threshold_hw=1000):
    """Quantize a float weight matrix to the hardware integer range.

    Scales weights so that hardware dynamics (integer weights against
    threshold_hw) match training dynamics (float weights against
    threshold_float):

        weight_hw = round(w_float * threshold_hw / threshold_float)

    clamped to [WEIGHT_MIN, WEIGHT_MAX] = [-32768, 32767] (the int16 range,
    stored in an int32 array).

    Args:
        w_float: (out, in) float32 weight matrix from nn.Linear
        threshold_float: threshold used in training (e.g. 1.0)
        threshold_hw: hardware threshold (default 1000)

    Returns:
        (in, out) int32 weight matrix, transposed to the SDK's
        source->target convention.
    """
    rounded = np.round(w_float * (threshold_hw / threshold_float))
    clamped = np.clip(rounded, WEIGHT_MIN, WEIGHT_MAX)
    return clamped.astype(np.int32).T
|
|
|
|
|
def detect_neuron_type(checkpoint):
    """Auto-detect neuron type from checkpoint state dict keys."""
    # adLIF checkpoints carry a learned membrane-decay alpha; LIF ones don't.
    has_alpha = 'lif1.alpha_raw' in checkpoint['model_state_dict']
    return 'adlif' if has_alpha else 'lif'
|
|
|
|
|
def compute_hardware_params(checkpoint, threshold_hw=1000, neuron_type=None):
    """Compute hardware neuron parameters from trained model.

    Maps membrane decay to CUBA neuron decay_v:
        decay_v = round(decay * 4096)   (12-bit fractional)

    For LIF: decay = beta (from lif1.beta_raw)
    For adLIF: decay = alpha (from lif1.alpha_raw)
    adLIF adaptation params (rho, beta_a) are training-only.

    Args:
        checkpoint: checkpoint dict containing 'model_state_dict'
        threshold_hw: hardware firing threshold, echoed into the result
        neuron_type: 'lif' or 'adlif'; auto-detected when None

    Returns:
        dict with hardware parameters for each layer
    """
    state = checkpoint['model_state_dict']
    if neuron_type is None:
        neuron_type = detect_neuron_type(checkpoint)

    params = {'neuron_type': neuron_type}

    if neuron_type == 'adlif':
        # Membrane decay alpha is the only adLIF parameter that deploys.
        if 'lif1.alpha_raw' in state:
            alpha = torch.sigmoid(state['lif1.alpha_raw']).cpu().numpy()
            mean_alpha = float(alpha.mean())
            params['hidden_alpha_mean'] = mean_alpha
            params['hidden_alpha_std'] = float(alpha.std())
            params['hidden_decay_v'] = int(round(mean_alpha * 4096))
            # Alias under the beta key so downstream leak computation can
            # read one key regardless of neuron type.
            params['hidden_beta_mean'] = mean_alpha

        # Adaptation decay rho: reported for inspection only.
        if 'lif1.rho_raw' in state:
            rho = torch.sigmoid(state['lif1.rho_raw']).cpu().numpy()
            params['hidden_rho_mean'] = float(rho.mean())
            params['hidden_rho_note'] = 'training-only (not deployed)'

        # Adaptation strength beta_a: reported for inspection only.
        if 'lif1.beta_a_raw' in state:
            import torch.nn.functional as F_
            beta_a = F_.softplus(state['lif1.beta_a_raw']).cpu().numpy()
            params['hidden_beta_a_mean'] = float(beta_a.mean())
            params['hidden_beta_a_note'] = 'training-only (not deployed)'
    else:
        # Plain LIF: beta is the membrane decay of the hidden layer.
        if 'lif1.beta_raw' in state:
            beta_hid = torch.sigmoid(state['lif1.beta_raw']).cpu().numpy()
            params['hidden_beta_mean'] = float(beta_hid.mean())
            params['hidden_beta_std'] = float(beta_hid.std())
            params['hidden_decay_v'] = int(round(beta_hid.mean() * 4096))

    # Output layer uses a plain beta decay for both neuron types.
    if 'lif2.beta_raw' in state:
        beta_out = torch.sigmoid(state['lif2.beta_raw']).cpu().numpy()
        params['output_beta_mean'] = float(beta_out.mean())
        params['output_beta_std'] = float(beta_out.std())
        params['output_decay_v'] = int(round(beta_out.mean() * 4096))

    params['threshold_hw'] = threshold_hw
    return params
|
|
|
|
|
def build_sdk_network(checkpoint, threshold_hw=1000):
    """Build SDK Network from a trained PyTorch checkpoint.

    Uses subtractive leak as approximation for multiplicative decay.
    True hardware deployment would use CUBA mode with decay_v.

    Args:
        checkpoint: checkpoint dict from shd_train.py ('args' and
            'model_state_dict')
        threshold_hw: integer hardware firing threshold

    Returns:
        net: Network ready for deploy()
        n_hidden: hidden layer size (for reporting)
    """
    train_args = checkpoint['args']
    threshold_float = train_args['threshold']
    n_hidden = train_args['hidden']

    # Quantize all three weight matrices; the helper also transposes them
    # into the SDK's source->target layout.
    state = checkpoint['model_state_dict']
    wm_fc1, wm_fc2, wm_rec = (
        quantize_weights(state[key].cpu().numpy(), threshold_float, threshold_hw)
        for key in ('fc1.weight', 'fc2.weight', 'fc_rec.weight')
    )

    # Subtractive-leak approximation: leak ~ (1 - decay) * threshold,
    # floored at 1 so neurons always leak.
    hw = compute_hardware_params(checkpoint, threshold_hw)
    leak_hid = max(1, int(round((1 - hw.get('hidden_beta_mean', 0.95)) * threshold_hw)))
    leak_out = max(1, int(round((1 - hw.get('output_beta_mean', 0.9)) * threshold_hw)))

    net = Network()
    # NOTE(review): input threshold 65535 looks like "max out the threshold";
    # confirm against the SDK how input spikes are injected.
    pop_in = net.population(N_CHANNELS,
                            params={'threshold': 65535, 'leak': 0, 'refrac': 0},
                            label="input")
    pop_hid = net.population(n_hidden,
                             params={'threshold': threshold_hw, 'leak': leak_hid, 'refrac': 0},
                             label="hidden")
    pop_out = net.population(N_CLASSES,
                             params={'threshold': threshold_hw, 'leak': leak_out, 'refrac': 0},
                             label="output")

    net.connect(pop_in, pop_hid, weight_matrix=wm_fc1)
    net.connect(pop_hid, pop_out, weight_matrix=wm_fc2)
    net.connect(pop_hid, pop_hid, weight_matrix=wm_rec)  # recurrent hidden loop

    # Report sparsity and value ranges of the deployed matrices.
    matrices = (('fc1', wm_fc1), ('fc2', wm_fc2), ('rec', wm_rec))
    nonzeros = {name: np.count_nonzero(m) for name, m in matrices}
    print(f"Quantized weights (threshold_hw={threshold_hw}):")
    for name, m in matrices:
        print(f"  {name}: {m.shape}, {nonzeros[name]:,} nonzero, "
              f"range [{m.min()}, {m.max()}]")
    print(f"  Total connections: {sum(nonzeros.values()):,}")
    for layer in ('hidden', 'output'):
        if f'{layer}_decay_v' in hw:
            print(f"  Hardware decay_v ({layer}): {hw[f'{layer}_decay_v']} "
                  f"(beta={hw[f'{layer}_beta_mean']:.4f})")

    return net, n_hidden
|
|
|
|
|
def run_pytorch_quantized_inference(checkpoint, test_ds, device='cpu',
                                    neuron_type=None, threshold_hw=1000):
    """Run inference with quantized weights in PyTorch (for comparison).

    Loads the model, replaces float weights with quantized int versions
    (converted back to float), and runs normal forward pass.

    Args:
        checkpoint: checkpoint dict from shd_train.py ('args' and
            'model_state_dict')
        test_ds: test dataset yielding (spikes, label) samples
        device: torch device string to run inference on
        neuron_type: 'lif' or 'adlif'; resolved from the checkpoint when None
        threshold_hw: hardware threshold defining the quantization grid.
            Previously hard-coded to 1000; now a parameter (default 1000)
            for consistency with build_sdk_network/compute_hardware_params.

    Returns:
        float accuracy in [0, 1]
    """
    args = checkpoint['args']
    threshold_float = args['threshold']
    if neuron_type is None:
        neuron_type = args.get('neuron_type', detect_neuron_type(checkpoint))

    # Rebuild the training architecture (dropout disabled for eval).
    model = SHDSNN(
        n_hidden=args['hidden'],
        threshold=args['threshold'],
        beta_hidden=args.get('beta_hidden', 0.95),
        beta_out=args.get('beta_out', 0.9),
        dropout=0.0,
        neuron_type=neuron_type,
        alpha_init=args.get('alpha_init', 0.90),
        rho_init=args.get('rho_init', 0.85),
        beta_a_init=args.get('beta_a_init', 1.8),
    ).to(device)
    model.load_state_dict(checkpoint['model_state_dict'])

    # Simulate hardware quantization in float: snap each weight to the
    # integer grid used on chip, clamp to the representable range, then
    # map back so the forward pass is otherwise unchanged.
    scale = threshold_hw / threshold_float
    # Neuron-dynamics parameters (decays, adaptation, threshold) are not
    # weight matrices and must not be quantized.
    skip_keys = ('beta', 'alpha', 'rho', 'threshold_base')
    with torch.no_grad():
        for name, param in model.named_parameters():
            if 'weight' in name and not any(k in name for k in skip_keys):
                q = torch.round(param * scale).clamp(WEIGHT_MIN, WEIGHT_MAX) / scale
                param.copy_(q)

    model.eval()
    loader = DataLoader(test_ds, batch_size=128, shuffle=False,
                        collate_fn=collate_fn, num_workers=0)

    correct = 0
    total = 0
    with torch.no_grad():
        for inputs, labels in loader:
            inputs, labels = inputs.to(device), labels.to(device)
            output = model(inputs)
            correct += (output.argmax(1) == labels).sum().item()
            total += inputs.size(0)

    acc = correct / total
    print(f"  PyTorch quantized accuracy: {correct}/{total} = {acc*100:.1f}%")
    return acc
|
|
|
|
|
def main():
    """Load a checkpoint, evaluate quantized accuracy, and build the SDK net."""
    parser = argparse.ArgumentParser(description="Deploy trained SHD model")
    parser.add_argument("--checkpoint", default="shd_model.pt",
                        help="Path to trained model checkpoint")
    parser.add_argument("--data-dir", default="data/shd")
    # NOTE(review): --n-samples is parsed but currently unused; the full
    # test set is always evaluated.
    parser.add_argument("--n-samples", type=int, default=None,
                        help="Limit test samples (default: all)")
    parser.add_argument("--threshold-hw", type=int, default=1000)
    parser.add_argument("--dt", type=float, default=4e-3)
    parser.add_argument("--neuron-type", choices=["lif", "adlif"], default=None,
                        help="Neuron model (auto-detected from checkpoint if omitted)")
    args = parser.parse_args()

    print(f"Loading checkpoint: {args.checkpoint}")
    # weights_only=False because the checkpoint stores plain Python objects
    # (the training-args dict) alongside tensors; only load trusted files.
    ckpt = torch.load(args.checkpoint, map_location='cpu', weights_only=False)
    train_args = ckpt['args']

    # Resolution order: explicit CLI flag > value saved at training time >
    # auto-detection from state-dict keys.
    neuron_type = args.neuron_type or train_args.get('neuron_type', detect_neuron_type(ckpt))
    print(f"  Training accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f"  Architecture: {N_CHANNELS}->{train_args['hidden']}->{N_CLASSES} ({neuron_type.upper()})")

    print("\nLoading test dataset...")
    test_ds = SHDDataset(args.data_dir, "test", dt=args.dt)
    print(f"  {len(test_ds)} samples, {test_ds.n_bins} time bins")

    print("\n--- Hardware parameter mapping ---")
    hw_params = compute_hardware_params(ckpt, args.threshold_hw, neuron_type)
    for k, v in sorted(hw_params.items()):
        print(f"  {k}: {v}")

    print("\n--- PyTorch quantized inference ---")
    pytorch_acc = run_pytorch_quantized_inference(ckpt, test_ds,
                                                  neuron_type=neuron_type)

    print("\n--- SDK network summary ---")
    net, n_hidden = build_sdk_network(ckpt, threshold_hw=args.threshold_hw)

    print("\n=== Results ===")
    print(f"  PyTorch float accuracy: {ckpt['test_acc']*100:.1f}%")
    print(f"  PyTorch quantized accuracy: {pytorch_acc*100:.1f}%")
    gap = abs(ckpt['test_acc'] - pytorch_acc) * 100
    print(f"  Quantization loss: {gap:.1f}%")
    print(f"\n  Hardware deployment: CUBA mode (decay_v={hw_params.get('hidden_decay_v', 'N/A')})")
    # NOTE(review): this counts Connection objects (3 for this topology),
    # not individual synapses; per-synapse totals are printed by
    # build_sdk_network. The inner `for _ in range(1)` no-op was removed.
    print(f"  Total synapses: {sum(1 for _ in net.connections):,}")
|
|
|
|
# Run deployment only when executed as a script (module stays importable).
if __name__ == "__main__":
    main()
|
|
|