| |
"""Run all notebook cells as a script to verify everything works."""
import os
import sys

# Make the project package importable when run outside the repo root.
sys.path.insert(0, '/app')
# Non-interactive matplotlib backend: no display is needed (or available).
os.environ['MPLBACKEND'] = 'Agg'

import torch
import numpy as np

# CPU keeps this verification run deterministic and hardware-independent.
device = torch.device('cpu')
print(f'Using device: {device}')
|
|
| |
print("\n=== Cell 1: Architecture Overview ===")
from lrf.model import LatentRecurrentFlow, RecursiveLatentCore, CompactVAE, GatedLinearAttention
from lrf.training import LRFTrainer, RectifiedFlowScheduler, SyntheticImageTextDataset
from lrf.pipeline import LRFPipeline, LRFTrainingPipeline

# Compare the parameter budgets of the two published configurations.
configs = {
    'Tiny (5.7M)': LatentRecurrentFlow.tiny_config(),
    'Default (16.3M)': LatentRecurrentFlow.default_config(),
}

for label, cfg in configs.items():
    net = LatentRecurrentFlow(cfg)
    per_module = net.count_parameters()
    print(f'\n{label}:')
    for module, count in per_module.items():
        print(f' {module:20s}: {count:>12,}')
    # Drop the throwaway model before instantiating the next one.
    del net
|
|
| |
print("\n=== Cell 2: VAE Training ===")
config = LatentRecurrentFlow.tiny_config()
model = LatentRecurrentFlow(config).to(device)
from torch.utils.data import DataLoader

# A tiny synthetic dataset keeps this smoke test fast and self-contained.
dataset = SyntheticImageTextDataset(num_samples=100, image_size=64, max_text_length=32)
dataloader = DataLoader(dataset, batch_size=8, shuffle=True, num_workers=0)

trainer = LRFTrainer(model, device, '/app/nb_checkpoints')

# Stage 1: optimize only the VAE parameters for a handful of steps.
vae_optimizer = torch.optim.AdamW(model.vae.parameters(), lr=1e-3, weight_decay=0.01)
for step, batch in enumerate(dataloader):
    if step >= 10:
        break
    losses = trainer.train_vae_step(batch['image'], vae_optimizer)
    if step % 5 == 0:
        print(f' VAE step {step}: loss={losses["total"]:.4f}')

trainer.save_checkpoint('/app/nb_checkpoints/vae.pt', 'vae', 0)
|
|
| |
# Sanity-check reconstruction quality on a single batch, autograd disabled.
model.eval()
with torch.no_grad():
    probe = next(iter(dataloader))
    inputs = probe['image'].to(device)
    recon, _, _ = model.vae(inputs)
    mse = ((recon - inputs) ** 2).mean()
    print(f' Reconstruction MSE: {mse:.4f}')
|
|
| |
print("\n=== Cell 3: Flow Matching Training ===")
# Stage 2: freeze the VAE so only the flow core and text encoder learn.
model.vae.requires_grad_(False)

flow_params = [*model.core.parameters(), *model.text_encoder.parameters()]
flow_optimizer = torch.optim.AdamW(flow_params, lr=1e-3, weight_decay=0.01)

model.core.train()
model.text_encoder.train()

for step, batch in enumerate(dataloader):
    if step >= 10:
        break
    losses = trainer.train_flow_step(
        batch['image'], batch['token_ids'], batch['attention_mask'],
        flow_optimizer, cfg_dropout=0.1
    )
    if step % 5 == 0:
        print(f' Flow step {step}: loss={losses["flow_loss"]:.4f}')

trainer.save_checkpoint('/app/nb_checkpoints/flow.pt', 'flow', 0)
|
|
| |
print("\n=== Cell 4: Generation ===")
# Switch to eval mode and sample a small batch of images end-to-end.
model.eval()
pipe = LRFPipeline(model, device=device)

prompts = ['a sunset', 'a cat', 'mountains', 'abstract art']
# Fixed seed makes the smoke test reproducible run-to-run.
images = pipe(prompts, num_steps=5, cfg_scale=1.0, height=64, width=64, seed=42)
print(f' Generated {images.shape[0]} images: {images.shape}')
print(f' Range: [{images.min():.3f}, {images.max():.3f}]')
|
|
| |
print("\n=== Cell 5: Save & Load ===")
# Round-trip the pipeline through disk and confirm the reload generates.
pipe.save_pretrained('/app/nb_model')
print(' Model saved to /app/nb_model/')
for fname in os.listdir('/app/nb_model'):
    nbytes = os.path.getsize(os.path.join('/app/nb_model', fname))
    print(f' {fname}: {nbytes/1024:.1f} KB')

pipe_loaded = LRFPipeline.from_pretrained('/app/nb_model', device=str(device))
images_loaded = pipe_loaded('test prompt', num_steps=5, height=64, width=64, seed=42)
print(f' Reloaded model generates: {images_loaded.shape}')
|
|
| |
print("\n=== Cell 6: Training Curriculum ===")
# Print each curriculum stage with its 1-based position.
curriculum = LRFTrainingPipeline.get_curriculum()
for idx, stage_name in enumerate(curriculum, start=1):
    stage = LRFTrainingPipeline.get_stage_config(stage_name)
    print(f' Stage {idx}: {stage_name} - {stage["description"]}')
|
|
| |
print("\n=== Cell 7: Core Architecture ===")
# Instantiate a minimal recursive core and report its effective depth
# (outer iterations x inner iterations x blocks) and parameter count.
core = RecursiveLatentCore(
    dim=32, cond_dim=64, num_blocks=2, num_heads=2, head_dim=16,
    T_inner=4, T_outer=2, use_ift_training=False
)
effective_depth = core.T_outer * core.T_inner * core.num_blocks
total_params = sum(p.numel() for p in core.parameters())
print(f' Effective depth: {effective_depth} layers')
print(f' Parameters: {total_params:,}')
|
|
| |
print("\n=== Cell 8: GLA Scaling ===")
import time

# Time GLA forward passes at increasing token counts.
# Fixes vs. original: time.perf_counter() (monotonic, high-resolution) replaces
# time.time() (wall clock, can jump with NTP adjustments), and torch.no_grad()
# stops autograd graphs from being built during the benchmark loop.
gla = GatedLinearAttention(dim=64, num_heads=4, head_dim=16)
with torch.no_grad():
    for s in [4, 8, 16, 32]:
        x = torch.randn(1, s*s, 64)
        _ = gla(x, h=s, w=s)  # warm-up call, excluded from timing
        t0 = time.perf_counter()
        for _ in range(5):
            _ = gla(x, h=s, w=s)
        dt = (time.perf_counter() - t0) / 5
        print(f' {s}×{s} = {s*s:>5} tokens: {dt*1000:.2f}ms')

print("\n✅ ALL NOTEBOOK CELLS VERIFIED!")
|
|