import os
import sys

# Make the repository root importable when the tests are run from a source checkout.
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))

import torch

from bit_transformer import BitTransformerLM, distill_step, TelemetryLog


def test_distill_prunes_weights():
    """Distilling with scale=0.5 should zero at least half of all Linear weights."""
    model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)

    # Synthetic attention maps shaped (batch, heads, seq, seq) for the telemetry log.
    attn = torch.rand(2, 4, 8, 8)
    telemetry = TelemetryLog(attention_maps=attn)

    pruned = distill_step(model, scale=0.5, telemetry=telemetry)

    # Count exactly-zero entries across all Linear weight matrices of the pruned model.
    total = 0
    zeros = 0
    for m in pruned.modules():
        if isinstance(m, torch.nn.Linear):
            w = m.weight.detach()
            total += w.numel()
            zeros += (w == 0).sum().item()

    assert zeros >= int(total * 0.5)
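

# A minimal sketch (not part of bit_transformer; the helper name is hypothetical) of the
# same sparsity check factored into a reusable function, assuming pruning zeroes Linear
# weights in place rather than removing parameters.
def linear_weight_sparsity(module: torch.nn.Module) -> float:
    """Return the fraction of exactly-zero weights across all torch.nn.Linear layers."""
    total = 0
    zeros = 0
    for m in module.modules():
        if isinstance(m, torch.nn.Linear):
            w = m.weight.detach()
            total += w.numel()
            zeros += (w == 0).sum().item()
    return zeros / max(total, 1)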