krystv commited on
Commit
fee179c
·
verified ·
1 Parent(s): ad55fd7

Upload tests/test_lrf.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. tests/test_lrf.py +402 -0
tests/test_lrf.py ADDED
@@ -0,0 +1,402 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ #!/usr/bin/env python3
2
+ """
3
+ ============================================================================
4
+ LatentRecurrentFlow (LRF) — End-to-End Test Script
5
+ ============================================================================
6
+
7
+ Tests the full pipeline on CPU:
8
+ 1. Model creation and parameter counting
9
+ 2. VAE forward pass
10
+ 3. Flow matching forward pass
11
+ 4. Recursive latent core forward pass
12
+ 5. Full training loop (few steps)
13
+ 6. Sample generation
14
+ 7. Checkpoint save/load
15
+
16
+ Run: python test_lrf.py
17
+ """
18
+
19
+ import sys
20
+ import os
21
+ import time
22
+ import torch
23
+ import traceback
24
+
25
+ # Add project root
26
+ sys.path.insert(0, '/app')
27
+
28
+ def test_model_creation():
29
+ """Test model creation with different configs."""
30
+ print("\n[TEST 1] Model Creation")
31
+ print("-" * 40)
32
+
33
+ from lrf.model import LatentRecurrentFlow
34
+
35
+ # Test tiny config
36
+ model = LatentRecurrentFlow(LatentRecurrentFlow.tiny_config())
37
+ counts = model.count_parameters()
38
+ print("Tiny config parameters:")
39
+ for name, count in counts.items():
40
+ print(f" {name}: {count:,}")
41
+ assert counts['total'] > 0, "Model has no parameters!"
42
+
43
+ # Test default config
44
+ model_default = LatentRecurrentFlow(LatentRecurrentFlow.default_config())
45
+ counts_default = model_default.count_parameters()
46
+ print("\nDefault config parameters:")
47
+ for name, count in counts_default.items():
48
+ print(f" {name}: {count:,}")
49
+ assert counts_default['total'] > counts['total'], "Default should be larger than tiny"
50
+
51
+ print("✓ Model creation passed")
52
+ return True
53
+
54
+
55
+ def test_vae():
56
+ """Test VAE forward and backward."""
57
+ print("\n[TEST 2] VAE Forward/Backward")
58
+ print("-" * 40)
59
+
60
+ from lrf.model import CompactVAE
61
+
62
+ vae = CompactVAE(in_channels=3, latent_channels=16, encoder_base_ch=32, decoder_base_ch=64)
63
+
64
+ # Count params
65
+ enc_params = sum(p.numel() for p in vae.encoder.parameters())
66
+ dec_params = sum(p.numel() for p in vae.decoder.parameters())
67
+ print(f"Encoder params: {enc_params:,}")
68
+ print(f"Decoder params: {dec_params:,}")
69
+
70
+ # Forward
71
+ x = torch.randn(2, 3, 64, 64)
72
+ recon, mean, logvar = vae(x)
73
+ print(f"Input shape: {x.shape}")
74
+ print(f"Latent shape: {mean.shape}")
75
+ print(f"Recon shape: {recon.shape}")
76
+
77
+ assert recon.shape == x.shape, f"Reconstruction shape mismatch: {recon.shape} != {x.shape}"
78
+ assert mean.shape[1] == 16, f"Latent channels mismatch: {mean.shape[1]}"
79
+
80
+ # Backward
81
+ loss = F.l1_loss(recon, x) - 0.5 * torch.mean(1 + logvar - mean.pow(2) - logvar.exp()) * 1e-6
82
+ loss.backward()
83
+
84
+ # Check gradients
85
+ grad_ok = all(p.grad is not None for p in vae.parameters() if p.requires_grad)
86
+ print(f"Gradients computed: {grad_ok}")
87
+
88
+ print("✓ VAE test passed")
89
+ return True
90
+
91
+
92
+ def test_gla():
93
+ """Test Gated Linear Attention."""
94
+ print("\n[TEST 3] Gated Linear Attention")
95
+ print("-" * 40)
96
+
97
+ from lrf.model import GatedLinearAttention
98
+
99
+ gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)
100
+
101
+ B, H, W, D = 2, 8, 8, 64
102
+ x = torch.randn(B, H * W, D)
103
+
104
+ t0 = time.time()
105
+ out = gla(x, h=H, w=W)
106
+ dt = time.time() - t0
107
+
108
+ print(f"Input: {x.shape}")
109
+ print(f"Output: {out.shape}")
110
+ print(f"Time: {dt*1000:.1f}ms")
111
+
112
+ assert out.shape == x.shape, f"Shape mismatch: {out.shape}"
113
+
114
+ # Test with larger sequence
115
+ B, H, W, D = 1, 32, 32, 64
116
+ x_large = torch.randn(B, H * W, D)
117
+ t0 = time.time()
118
+ out_large = gla(x_large, h=H, w=W)
119
+ dt_large = time.time() - t0
120
+ print(f"\nLarger input (32x32={H*W} tokens):")
121
+ print(f" Time: {dt_large*1000:.1f}ms")
122
+
123
+ print("✓ GLA test passed")
124
+ return True
125
+
126
+
127
+ def test_recursive_core():
128
+ """Test the Recursive Latent Core."""
129
+ print("\n[TEST 4] Recursive Latent Core")
130
+ print("-" * 40)
131
+
132
+ from lrf.model import RecursiveLatentCore
133
+
134
+ core = RecursiveLatentCore(
135
+ dim=32,
136
+ cond_dim=64,
137
+ num_blocks=2,
138
+ num_heads=2,
139
+ head_dim=16,
140
+ T_inner=2,
141
+ T_outer=1,
142
+ use_ift_training=False,
143
+ )
144
+
145
+ params = sum(p.numel() for p in core.parameters())
146
+ print(f"Core params: {params:,}")
147
+
148
+ B, C, H, W = 2, 32, 4, 4
149
+ z_t = torch.randn(B, C, H, W)
150
+ t = torch.rand(B)
151
+ text_emb = torch.randn(B, 10, 64)
152
+ text_global = torch.randn(B, 64)
153
+
154
+ # Forward
155
+ t0 = time.time()
156
+ v = core(z_t, t, text_emb, text_global)
157
+ dt = time.time() - t0
158
+
159
+ print(f"Input shape: {z_t.shape}")
160
+ print(f"Output shape: {v.shape}")
161
+ print(f"Time: {dt*1000:.1f}ms")
162
+
163
+ assert v.shape == z_t.shape, f"Shape mismatch: {v.shape}"
164
+
165
+ # Backward
166
+ loss = v.pow(2).mean()
167
+ loss.backward()
168
+
169
+ grad_ok = sum(1 for p in core.parameters() if p.grad is not None and p.requires_grad)
170
+ total_params = sum(1 for p in core.parameters() if p.requires_grad)
171
+ print(f"Params with grad: {grad_ok}/{total_params}")
172
+
173
+ print("✓ Recursive core test passed")
174
+ return True
175
+
176
+
177
+ def test_ift_training():
178
+ """Test IFT (Implicit Function Theorem) training mode."""
179
+ print("\n[TEST 5] IFT Training Mode")
180
+ print("-" * 40)
181
+
182
+ from lrf.model import RecursiveLatentCore
183
+
184
+ # Test with IFT enabled
185
+ core_ift = RecursiveLatentCore(
186
+ dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,
187
+ T_inner=3, T_outer=2, use_ift_training=True,
188
+ )
189
+ core_ift.train()
190
+
191
+ z_t = torch.randn(2, 32, 4, 4, requires_grad=True)
192
+ t = torch.rand(2)
193
+
194
+ v = core_ift(z_t, t)
195
+ loss = v.pow(2).mean()
196
+ loss.backward()
197
+
198
+ print(f"IFT mode: loss={loss.item():.4f}")
199
+ print(f" T_outer={core_ift.T_outer}, T_inner={core_ift.T_inner}")
200
+ print(f" Effective depth: {core_ift.T_outer * core_ift.T_inner * core_ift.num_blocks} layers")
201
+ print(f" Actual blocks: {core_ift.num_blocks}")
202
+
203
+ print("✓ IFT training test passed")
204
+ return True
205
+
206
+
207
+ def test_flow_matching():
208
+ """Test flow matching scheduler."""
209
+ print("\n[TEST 6] Flow Matching Scheduler")
210
+ print("-" * 40)
211
+
212
+ from lrf.training import RectifiedFlowScheduler
213
+
214
+ scheduler = RectifiedFlowScheduler(shift=1.0)
215
+
216
+ z_0 = torch.randn(2, 16, 4, 4)
217
+ noise = torch.randn_like(z_0)
218
+ t = torch.tensor([0.0, 0.5])
219
+
220
+ z_t = scheduler.add_noise(z_0, noise, t)
221
+ v_target = scheduler.get_velocity_target(z_0, noise)
222
+
223
+ print(f"z_0 shape: {z_0.shape}")
224
+ print(f"z_t shape: {z_t.shape}")
225
+ print(f"v_target shape: {v_target.shape}")
226
+
227
+ # At t=0, z_t should equal z_0
228
+ t_zero = torch.tensor([0.0, 0.0])
229
+ z_t_zero = scheduler.add_noise(z_0, noise, t_zero)
230
+ diff = (z_t_zero - z_0).abs().max().item()
231
+ print(f"At t=0, |z_t - z_0| max = {diff:.6f}")
232
+ assert diff < 1e-5, f"At t=0, z_t should equal z_0, got diff={diff}"
233
+
234
+ # At t=1, z_t should equal noise
235
+ t_one = torch.tensor([1.0, 1.0])
236
+ z_t_one = scheduler.add_noise(z_0, noise, t_one)
237
+ diff_one = (z_t_one - noise).abs().max().item()
238
+ print(f"At t=1, |z_t - noise| max = {diff_one:.6f}")
239
+ assert diff_one < 1e-5, f"At t=1, z_t should equal noise, got diff={diff_one}"
240
+
241
+ print("✓ Flow matching test passed")
242
+ return True
243
+
244
+
245
+ def test_full_training():
246
+ """Test full training pipeline."""
247
+ print("\n[TEST 7] Full Training Pipeline")
248
+ print("-" * 40)
249
+
250
+ from lrf.model import LatentRecurrentFlow
251
+ from lrf.training import LRFTrainer, SyntheticImageTextDataset
252
+ from torch.utils.data import DataLoader
253
+
254
+ config = LatentRecurrentFlow.tiny_config()
255
+ model = LatentRecurrentFlow(config)
256
+
257
+ trainer = LRFTrainer(model, torch.device('cpu'), '/app/test_checkpoints')
258
+
259
+ dataset = SyntheticImageTextDataset(num_samples=16, image_size=64, max_text_length=32)
260
+ dataloader = DataLoader(dataset, batch_size=4, shuffle=True)
261
+
262
+ # Stage 1: VAE
263
+ print(" Training VAE...")
264
+ vae_opt = torch.optim.AdamW(model.vae.parameters(), lr=1e-3)
265
+ for i, batch in enumerate(dataloader):
266
+ if i >= 3:
267
+ break
268
+ losses = trainer.train_vae_step(batch['image'], vae_opt)
269
+ print(f" VAE step {i}: loss={losses['total']:.4f}")
270
+
271
+ # Stage 2: Flow matching
272
+ print(" Training flow matching...")
273
+ for p in model.vae.parameters():
274
+ p.requires_grad = False
275
+
276
+ flow_params = list(model.core.parameters()) + list(model.text_encoder.parameters())
277
+ flow_opt = torch.optim.AdamW(flow_params, lr=1e-3)
278
+
279
+ for i, batch in enumerate(dataloader):
280
+ if i >= 3:
281
+ break
282
+ losses = trainer.train_flow_step(
283
+ batch['image'], batch['token_ids'], batch['attention_mask'],
284
+ flow_opt,
285
+ )
286
+ print(f" Flow step {i}: loss={losses['flow_loss']:.4f}")
287
+
288
+ # Generate
289
+ print(" Generating samples...")
290
+ sample_tokens = torch.randint(1, 31999, (2, 32))
291
+ sample_mask = torch.ones(2, 32)
292
+
293
+ images = trainer.generate(
294
+ sample_tokens, sample_mask,
295
+ num_steps=5, cfg_scale=1.0,
296
+ latent_h=4, latent_w=4,
297
+ )
298
+ print(f" Generated: {images.shape}, range=[{images.min():.3f}, {images.max():.3f}]")
299
+
300
+ # Save/load checkpoint
301
+ print(" Saving checkpoint...")
302
+ trainer.save_checkpoint('/app/test_checkpoints/test.pt', 'test', 0)
303
+ trainer.load_checkpoint('/app/test_checkpoints/test.pt')
304
+
305
+ print("✓ Full training pipeline test passed")
306
+ return True
307
+
308
+
309
+ def test_memory_estimate():
310
+ """Estimate memory usage for different configs."""
311
+ print("\n[TEST 8] Memory Estimation")
312
+ print("-" * 40)
313
+
314
+ from lrf.model import LatentRecurrentFlow
315
+
316
+ configs = {
317
+ 'tiny': LatentRecurrentFlow.tiny_config(),
318
+ 'default': LatentRecurrentFlow.default_config(),
319
+ }
320
+
321
+ for name, config in configs.items():
322
+ model = LatentRecurrentFlow(config)
323
+ counts = model.count_parameters()
324
+
325
+ # Estimate memory
326
+ param_bytes = counts['total'] * 4 # float32
327
+ param_mb = param_bytes / (1024 * 1024)
328
+
329
+ # INT8 deployment
330
+ param_int8_mb = counts['total'] * 1 / (1024 * 1024)
331
+
332
+ print(f"\n{name} config:")
333
+ print(f" Total params: {counts['total']:,}")
334
+ print(f" FP32 size: {param_mb:.1f} MB")
335
+ print(f" INT8 size: {param_int8_mb:.1f} MB")
336
+
337
+ # Estimate activation memory for 256x256 generation
338
+ latent_h = 256 // 16
339
+ latent_w = 256 // 16
340
+ latent_tokens = latent_h * latent_w
341
+ act_bytes = 2 * latent_tokens * config['latent_channels'] * 4 # Conservative
342
+ act_mb = act_bytes / (1024 * 1024)
343
+ print(f" Est. activation memory (256x256): {act_mb:.1f} MB")
344
+
345
+ del model
346
+
347
+ print("\n✓ Memory estimation passed")
348
+ return True
349
+
350
+
351
+ # Import F for backward test
352
+ import torch.nn.functional as F
353
+
354
+ def main():
355
+ """Run all tests."""
356
+ print("=" * 60)
357
+ print("LatentRecurrentFlow (LRF) - End-to-End Tests")
358
+ print("=" * 60)
359
+
360
+ tests = [
361
+ ("Model Creation", test_model_creation),
362
+ ("VAE", test_vae),
363
+ ("GLA", test_gla),
364
+ ("Recursive Core", test_recursive_core),
365
+ ("IFT Training", test_ift_training),
366
+ ("Flow Matching", test_flow_matching),
367
+ ("Full Training", test_full_training),
368
+ ("Memory Estimate", test_memory_estimate),
369
+ ]
370
+
371
+ results = []
372
+ for name, test_fn in tests:
373
+ try:
374
+ passed = test_fn()
375
+ results.append((name, passed))
376
+ except Exception as e:
377
+ print(f"\n✗ {name} FAILED: {e}")
378
+ traceback.print_exc()
379
+ results.append((name, False))
380
+
381
+ print("\n" + "=" * 60)
382
+ print("Test Summary")
383
+ print("=" * 60)
384
+
385
+ all_passed = True
386
+ for name, passed in results:
387
+ status = "✓ PASS" if passed else "✗ FAIL"
388
+ print(f" {status}: {name}")
389
+ if not passed:
390
+ all_passed = False
391
+
392
+ if all_passed:
393
+ print("\n✓ ALL TESTS PASSED!")
394
+ else:
395
+ print("\n✗ SOME TESTS FAILED!")
396
+ sys.exit(1)
397
+
398
+ return all_passed
399
+
400
+
401
+ if __name__ == '__main__':
402
+ main()