WCNegentropy committed on
Commit bc0d887 · verified · 1 parent: 6599dbe

Remove nested directory: BitTransformerLM/tests/test_model.py

Files changed (1)
  1. BitTransformerLM/tests/test_model.py +0 -304
BitTransformerLM/tests/test_model.py DELETED
@@ -1,304 +0,0 @@
- import os, sys; sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), "..")))
- from bit_transformer import (
-     BitTransformerLM,
-     hil_safe_inference,
-     text_to_bits,
-     bits_to_text,
-     plot_telemetry,
-     infer_long_sequence,
-     diffusion_inference,
-     compress_bits,
- )
- from bit_transformer.safety import SafetyGate
- import torch
- import torch.nn.functional as F
- import torch.nn as nn
- import pytest
-
- def test_forward_pass():
-     B, L = 2, 8
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=L)
-     bits = torch.randint(0, 2, (B, L), dtype=torch.long)
-     logits, telemetry = model(bits)
-     assert logits.shape == (B, L, 2)
-     required_keys = {
-         "negentropy_input",
-         "lz_complexity_input",
-         "negentropy_logits",
-         "lz_complexity_logits",
-         "symbiosis_kl",
-         "symbiosis_score",
-         "attention_entropy",
-         "attention_entropy_mean",
-     }
-     assert required_keys.issubset(telemetry.keys())
-     pred = logits[:, :-1, :].reshape(-1, 2)
-     target = bits[:, 1:].reshape(-1)
-     loss = F.cross_entropy(pred, target)
-     assert torch.isfinite(loss)
-
-
- def test_autocast_forward():
-     model = BitTransformerLM(
-         d_model=32,
-         nhead=4,
-         num_layers=1,
-         dim_feedforward=64,
-         max_seq_len=8,
-         use_autocast=True,
-     )
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     logits, _ = model(bits)
-     assert logits.shape == (1, 8, 2)
-
-
- def test_act_forward():
-     model = BitTransformerLM(
-         d_model=32,
-         nhead=4,
-         num_layers=2,
-         dim_feedforward=64,
-         max_seq_len=8,
-         use_act=True,
-     )
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     logits, tele = model(bits)
-     assert logits.shape == (1, 8, 2)
-     assert "halt_probs" in tele
-
-
- def test_act_skips_layers():
-     model = BitTransformerLM(
-         d_model=16,
-         nhead=4,
-         num_layers=3,
-         dim_feedforward=32,
-         max_seq_len=8,
-         use_act=True,
-         act_threshold=0.5,
-     )
-     for proj in model.halt_projs:
-         nn.init.constant_(proj.weight, 0.0)
-         nn.init.constant_(proj.bias, 10.0)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     _, tele = model(bits)
-     assert len(tele["halt_probs"]) < model.num_layers
-
-
- def test_hil_safety_gate():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     # Expect gate triggered with high floors
-     raised = False
-     try:
-         hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0)
-     except RuntimeError:
-         raised = True
-     assert raised
-
-
- def test_hil_safety_non_strict():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     out, _ = hil_safe_inference(model, bits, c_floor=1.0, s_floor=1.0, strict=False)
-     assert out.shape == bits.shape
-
-
- def test_safety_gate_burn_in():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     gate = SafetyGate(c_floor=1.0, s_floor=1.0, burn_in=1)
-     hil_safe_inference(model, bits, gate=gate)
-     with pytest.raises(RuntimeError):
-         hil_safe_inference(model, bits, gate=gate)
-
-
- def test_bit_io_roundtrip():
-     text = "hello"
-     bits = text_to_bits(text)
-     assert bits_to_text(bits) == text
-
-
- def test_plot_telemetry():
-     log = {
-         "negentropy": [0.6, 0.7, 0.4],
-         "lz_complexity": [0.5, 0.45, 0.6],
-         "symbiosis_score": [0.55, 0.6, 0.3],
-         "clusters": [0, 0, 1],
-     }
-     fig, axes = plot_telemetry(log)
-     assert len(axes) == 3
-     fig.clf()
-
-
- def test_metric_no_gradient_flow():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     bits = torch.randint(0, 2, (2, 8), dtype=torch.long)
-     logits, _ = model(bits)
-     loss = model.negentropy_logits(logits).mean() + model.lz_complexity_logits(logits).mean()
-     assert not loss.requires_grad
-     with pytest.raises(RuntimeError):
-         loss.backward()
-
-
- def test_negentropy_decompression_edge_case():
-     bits = torch.tensor([0, 1] * 8, dtype=torch.uint8)
-     comp = compress_bits(bits)
-     model = BitTransformerLM(d_model=16, nhead=2, num_layers=1, dim_feedforward=32, max_seq_len=bits.numel())
-     neg_comp = model.negentropy_kpi(comp.unsqueeze(0))
-     neg_raw = model.negentropy_kpi(bits.unsqueeze(0))
-     assert torch.allclose(neg_comp, neg_raw, atol=1e-6)
-
-
- def test_dynamic_quantization():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     from bit_transformer import quantize_dynamic
-
-     qmodel = quantize_dynamic(model)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     logits, _ = qmodel(bits)
-     assert logits.shape == (1, 8, 2)
-
-
- def test_qat_fx_roundtrip():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     from bit_transformer import prepare_qat_fx, convert_qat_fx
-
-     example_bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     qat_model = prepare_qat_fx(model)
-     qat_model.eval()
-     qmodel = convert_qat_fx(qat_model)
-
-     logits, _ = qmodel(example_bits)
-     assert logits.shape == (1, 8, 2)
-
-
- def test_fsdp_wrap():
-     import os
-     import torch
-     import torch.distributed as dist
-     from bit_transformer import BitTransformerLM, wrap_fsdp
-
-     if not dist.is_initialized():
-         os.environ.setdefault("MASTER_ADDR", "localhost")
-         os.environ.setdefault("MASTER_PORT", "29500")
-         dist.init_process_group("gloo", rank=0, world_size=1)
-     if not torch.cuda.is_available():
-         pytest.skip("CUDA not available")
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     fsdp_model = wrap_fsdp(model)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     logits, _ = fsdp_model(bits)
-     assert logits.shape == (1, 8, 2)
-     dist.destroy_process_group()
-
-
- def test_make_pipeline():
-     import pytest
-     import torch.distributed.rpc as rpc
-     from bit_transformer import BitTransformerLM, make_pipeline
-
-     if not rpc._is_current_rpc_agent_set():
-         pytest.skip("RPC not initialized")
-
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     pipe_model = make_pipeline(model, chunks=1)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     logits, _ = pipe_model(bits)
-     assert logits.shape == (1, 8, 2)
-
-
- def test_causal_attention():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     logits, tele = model(bits, causal=True)
-     assert logits.shape == (1, 8, 2)
-     attn = tele["attention_maps"][0]
-     upper = attn.triu(1)
-     assert torch.allclose(upper, torch.zeros_like(upper))
-
-
- def test_scaling_helpers():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     model = model.double_width()
-     assert model.d_model == 64
-     model = model.double_layers()
-     assert model.num_layers == 2
-
-
- def test_expand_positional_encoding():
-     model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
-     model.expand_positional_encoding(16)
-     assert model.pos_enc.pe.size(0) == 16
-
-
- def test_infer_long_sequence():
-     model = BitTransformerLM(d_model=16, nhead=4, num_layers=1, dim_feedforward=32, max_seq_len=8)
-     bits = torch.randint(0, 2, (12,), dtype=torch.long)
-     preds, logs = infer_long_sequence(model, bits, ctx_bits=8, overlap=4)
-     assert len(preds) == 12
-     assert len(logs) >= 2
-
-
- def test_chunking_disabled_when_non_causal():
-     model = BitTransformerLM(
-         d_model=32,
-         nhead=4,
-         num_layers=1,
-         dim_feedforward=64,
-         max_seq_len=8,
-         chunk_size=2,
-         full_attn_logging=True,
-     )
-     # Zero query/key/value projections so attention is uniformly distributed.
-     # This makes the test deterministic: any non-masked position receives equal
-     # weight, allowing us to rely solely on the chunking mask for the check.
-     nn.init.constant_(model.layers[0].self_attn.in_proj_weight, 0.0)
-     nn.init.constant_(model.layers[0].self_attn.in_proj_bias, 0.0)
-     # Disable dropout for deterministic attention weights.
-     model.eval()
-     for module in model.modules():
-         if isinstance(module, nn.Dropout):
-             module.p = 0.0
-     model.layers[0].self_attn.dropout = 0.0
-
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     _, tele_causal = model(bits, causal=True)
-     _, tele_noncausal = model(bits, causal=False)
-     attn_causal = tele_causal["attention_maps"][0]
-     attn_noncausal = tele_noncausal["attention_maps"][0]
-     # Causal mode keeps attention within chunk boundaries, while non-causal mode
-     # should permit cross-chunk attention.
-     assert attn_causal[0, 0, 0, 4] == 0
-     assert attn_noncausal[0, 0, 0, 4] > 0
-
-
- def test_diffusion_inference_generates_bits():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     out = diffusion_inference(model, length=8, steps=2, batch_size=2)
-     assert out.shape == (2, 8)
-     assert set(out.unique().tolist()).issubset({0, 1})
-
-
- def test_diffusion_inference_cosine_schedule():
-     model = BitTransformerLM(d_model=32, nhead=4, num_layers=1, dim_feedforward=64, max_seq_len=8)
-     out = diffusion_inference(model, length=8, steps=2, schedule="cosine")
-     assert out.shape == (1, 8)
-
-
- def test_chunking_restored_after_diffusion():
-     model = BitTransformerLM(
-         d_model=32,
-         nhead=4,
-         num_layers=1,
-         dim_feedforward=64,
-         max_seq_len=8,
-         chunk_size=2,
-         full_attn_logging=True,
-     )
-     bits = torch.randint(0, 2, (1, 8), dtype=torch.long)
-     _ = model(bits, causal=False)
-     assert model.layers[0].chunk_size == 2
-     _, tele = model(bits, causal=True)
-     attn = tele["attention_maps"][0]
-     assert attn[0, 0, 0, 4] == 0
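
The deletion above removes a nested duplicate of the test module. A minimal sketch of running the surviving copy through pytest's Python API, assuming (hypothetically) that the canonical file lives at the repository's top-level tests/test_model.py:

import sys

import pytest

if __name__ == "__main__":
    # pytest.main takes CLI-style arguments and returns an exit code;
    # the path below is an assumed location, not confirmed by this commit.
    sys.exit(pytest.main(["-q", "tests/test_model.py"]))

Run from the repository root, this is equivalent to invoking `pytest -q tests/test_model.py` on the command line.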