#!/usr/bin/env python3
"""
MicroForge End-to-End Test Suite
Validates all modules work correctly on CPU.
"""
import torch
import time
import sys
import os
# Ensure this script's directory is on sys.path so the `microforge` package imports resolve
sys.path.insert(0, os.path.dirname(os.path.abspath(__file__)))


def test_vae():
    """Test all VAE configurations."""
    from microforge.vae import MicroForgeVAE
    print("=" * 60)
    print("TEST: MicroForge VAE")
    print("=" * 60)
    for config in ['tiny', 'small', 'base']:
        vae = MicroForgeVAE(config=config)
        params = sum(p.numel() for p in vae.parameters())
        # Test forward pass
        x = torch.randn(1, 3, 256, 256)
        x_recon, mu, logvar = vae(x)
        assert x_recon.shape == x.shape, f"Recon shape mismatch: {x_recon.shape} vs {x.shape}"
        assert not torch.isnan(mu).any(), "NaN in mu"
        assert not torch.isnan(logvar).any(), "NaN in logvar"
        # Test encode/decode
        z = vae.get_latent(x)
        x_dec = vae.decode(z)
        assert x_dec.shape == x.shape
        # Test KL loss
        kl = MicroForgeVAE.kl_loss(mu, logvar)
        assert not torch.isnan(kl), "NaN in KL loss"
        print(f" [{config:>5}] PASS | params={params:,} | latent={mu.shape} | KL={kl.item():.2f}")
    print()


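# Illustrative sketch, not part of the test suite: the standard closed-form KL
# divergence between the encoder posterior N(mu, sigma^2) and N(0, I). It is an
# assumption that MicroForgeVAE.kl_loss uses this form and this sum-then-mean
# reduction; the helper name below is hypothetical.
def _reference_kl_loss(mu: torch.Tensor, logvar: torch.Tensor) -> torch.Tensor:
    # KL(N(mu, sigma^2) || N(0, 1)) = -0.5 * (1 + log sigma^2 - mu^2 - sigma^2)
    kl_per_dim = -0.5 * (1 + logvar - mu.pow(2) - logvar.exp())
    # Sum over latent dimensions, average over the batch.
    return kl_per_dim.flatten(1).sum(dim=1).mean()

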
def test_backbone():
    """Test all backbone configurations."""
    from microforge.backbone import MicroForgeBackbone
    print("=" * 60)
    print("TEST: MicroForge Backbone")
    print("=" * 60)
    for config in ['tiny', 'small', 'base']:
        lc = 16 if config == 'tiny' else 32
        backbone = MicroForgeBackbone(latent_channels=lc, config=config)
        params = sum(p.numel() for p in backbone.parameters())
        z = torch.randn(1, lc, 8, 8)
        t = torch.rand(1)
        text_emb = torch.randn(1, 10, 768)
        text_pooled = torch.randn(1, 768)
        start = time.time()
        v = backbone(z, t, text_emb, text_pooled)
        elapsed = (time.time() - start) * 1000
        assert v.shape == z.shape, f"Output shape mismatch: {v.shape} vs {z.shape}"
        assert not torch.isnan(v).any(), "NaN in velocity prediction"
        print(f" [{config:>5}] PASS | params={params:,} | latency={elapsed:.0f}ms")
    print()


def test_planner():
    """Test Recurrent Latent Planner."""
    from microforge.planner import RecurrentLatentPlanner
    print("=" * 60)
    print("TEST: Recurrent Latent Planner")
    print("=" * 60)
    planner = RecurrentLatentPlanner(
        num_plan_tokens=32, dim=384, text_dim=768, latent_channels=32
    )
    params = sum(p.numel() for p in planner.parameters())
    # Test initialization
    text_pooled = torch.randn(2, 768)
    plan = planner.initialize_plan(text_pooled, batch_size=2)
    assert plan.shape == (2, 32, 384), f"Plan shape: {plan.shape}"
    # Test forward
    img_tokens = torch.randn(2, 64, 32)  # 8x8 latent flattened
    t_emb = torch.randn(2, 384)
    plan_out, output = planner(img_tokens, plan, t_emb)
    assert plan_out.shape == (2, 32, 384)
    assert output.shape == (2, 32, 768)  # Projected to text_dim
    assert not torch.isnan(plan_out).any()
    assert not torch.isnan(output).any()
    # Test self-conditioning
    plan_next = planner.initialize_plan(text_pooled, 2, prev_plan=plan_out)
    assert plan_next.shape == plan.shape
    print(f" PASS | params={params:,} | plan_state={planner.get_plan_size_bytes()} bytes")
    print()


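# Illustrative sketch, not exercised by the tests: how the self-conditioning
# path demonstrated above might thread a plan across successive refinement
# steps, using only the initialize_plan / forward signatures from test_planner.
# The loop structure, step count, and fixed img_tokens are assumptions made
# purely for illustration; the helper name is hypothetical.
def _sketch_plan_self_conditioning(planner, text_pooled, img_tokens, t_emb, num_steps=4):
    batch = text_pooled.shape[0]
    plan = planner.initialize_plan(text_pooled, batch_size=batch)
    for _ in range(num_steps):
        # Refine the plan against the current image tokens.
        plan, _cond_tokens = planner(img_tokens, plan, t_emb)
        # Self-condition the next initialization on the refined plan.
        plan = planner.initialize_plan(text_pooled, batch, prev_plan=plan)
    return plan

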
def test_training():
    """Test training loop."""
    from microforge.vae import MicroForgeVAE
    from microforge.backbone import MicroForgeBackbone
    from microforge.planner import RecurrentLatentPlanner
    from microforge.training import MicroForgeTrainer, FlowMatchingScheduler
    print("=" * 60)
    print("TEST: Training Pipeline")
    print("=" * 60)
    vae = MicroForgeVAE(config='tiny').eval()
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
    trainer = MicroForgeTrainer(vae, backbone, planner, lr=1e-4, use_ema=True)
    # Test flow matching scheduler
    scheduler = FlowMatchingScheduler()
    t = scheduler.sample_timesteps(4, torch.device('cpu'))
    assert t.min() >= 0 and t.max() <= 1, f"Timesteps out of range: {t}"
    z_0 = torch.randn(4, 16, 4, 4)
    noise = torch.randn_like(z_0)
    z_t, v_target = scheduler.add_noise(z_0, noise, t)
    assert z_t.shape == z_0.shape
    assert v_target.shape == z_0.shape
    # Test training steps
    images = torch.randn(2, 3, 128, 128)
    text_emb = torch.randn(2, 10, 768)
    text_pooled = torch.randn(2, 768)
    losses = []
    for _ in range(5):
        step_losses = trainer.train_step(images, text_emb, text_pooled)
        losses.append(step_losses['flow'])
        assert not any(torch.isnan(torch.tensor(v)) for v in step_losses.values()), \
            f"NaN in losses: {step_losses}"
    print(f" 5 training steps: loss {losses[0]:.2f} -> {losses[-1]:.2f}")
    print(" PASS")
    print()


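# Illustrative sketch of what FlowMatchingScheduler.add_noise is assumed to
# compute: a rectified-flow style pairing that interpolates linearly between
# data and noise, with a constant target velocity along that path. The exact
# interpolation/sign convention of the real scheduler may differ; the helper
# name below is hypothetical.
def _sketch_flow_matching_pair(z_0: torch.Tensor, noise: torch.Tensor, t: torch.Tensor):
    # Broadcast per-sample timesteps over channel/spatial dims: (B,) -> (B, 1, 1, 1).
    t = t.view(-1, 1, 1, 1)
    z_t = (1 - t) * z_0 + t * noise   # Linear interpolation at time t.
    v_target = noise - z_0            # d(z_t)/dt along the straight path.
    return z_t, v_target

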
def test_pipeline():
    """Test end-to-end inference pipeline."""
    from microforge.vae import MicroForgeVAE
    from microforge.backbone import MicroForgeBackbone
    from microforge.planner import RecurrentLatentPlanner
    from microforge.pipeline import MicroForgePipeline, SimpleTextEncoder
    print("=" * 60)
    print("TEST: End-to-End Pipeline")
    print("=" * 60)
    vae = MicroForgeVAE(config='tiny')
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    planner = RecurrentLatentPlanner(num_plan_tokens=16, dim=256, text_dim=768, latent_channels=16)
    text_enc = SimpleTextEncoder(embed_dim=768, num_layers=2)
    pipeline = MicroForgePipeline(vae, backbone, text_enc, planner, device='cpu')
    # Test text2img
    tokens = torch.randint(0, 8192, (1, 10))
    start = time.time()
    images = pipeline.text2img(tokens, height=128, width=128, num_steps=2, cfg_scale=1.0, seed=42)
    t2i_time = time.time() - start
    assert images.shape == (1, 3, 128, 128), f"Wrong output shape: {images.shape}"
    assert images.min() >= -1 and images.max() <= 1, f"Range error: [{images.min()}, {images.max()}]"
    print(f" text2img: {images.shape} in {t2i_time:.2f}s | PASS")
    # Test parameter count
    params = pipeline.count_parameters()
    print(f" Total params: {params['total']:,}")
    # Test memory estimate
    mem = pipeline.get_memory_estimate(512, 512)
    print(f" Est. memory @512px: {mem['estimated_inference_mb']:.0f} MB")
    print(" PASS")
    print()


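# Illustrative sketch of the standard classifier-free guidance combination that
# a cfg_scale > 1.0 would typically apply to paired conditional/unconditional
# velocity predictions at each sampling step. Whether MicroForgePipeline uses
# exactly this formula is an assumption; the helper name is hypothetical.
def _sketch_cfg_velocity(v_cond: torch.Tensor, v_uncond: torch.Tensor, cfg_scale: float) -> torch.Tensor:
    # cfg_scale == 1.0 (as used in test_pipeline above) reduces to the conditional prediction.
    return v_uncond + cfg_scale * (v_cond - v_uncond)

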
def test_editing_pathway():
    """Test that the editing pathway works (spatial concat)."""
    from microforge.backbone import MicroForgeBackbone
    print("=" * 60)
    print("TEST: Editing Pathway (Spatial Concat)")
    print("=" * 60)
    backbone = MicroForgeBackbone(latent_channels=16, config='tiny')
    # Standard generation: 8x8 latent
    z_gen = torch.randn(1, 16, 8, 8)
    t = torch.rand(1)
    text_emb = torch.randn(1, 5, 768)
    text_pooled = torch.randn(1, 768)
    v_gen = backbone(z_gen, t, text_emb, text_pooled)
    assert v_gen.shape == z_gen.shape, f"Gen output shape: {v_gen.shape}"
    # Editing: 8x16 latent (target and source concatenated along width)
    z_edit = torch.randn(1, 16, 8, 16)  # Doubled width
    v_edit = backbone(z_edit, t, text_emb, text_pooled)
    assert v_edit.shape == z_edit.shape, f"Edit output shape: {v_edit.shape}"
    # Extract target velocity (left half)
    v_target = v_edit[..., :8]
    assert v_target.shape == z_gen.shape
    print(f" Generation: {z_gen.shape} -> {v_gen.shape} | PASS")
    print(f" Editing: {z_edit.shape} -> {v_edit.shape} | PASS")
    print()


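# Illustrative sketch of the packing/unpacking implied by the editing test
# above: the target and source latents are concatenated along width before the
# backbone call, and the target velocity is read back from the left half. How
# the real editing path assembles its inputs is an assumption; both helper
# names are hypothetical.
def _sketch_pack_edit_latent(z_target: torch.Tensor, z_source: torch.Tensor) -> torch.Tensor:
    # (B, C, H, W) + (B, C, H, W) -> (B, C, H, 2W), target on the left.
    return torch.cat([z_target, z_source], dim=-1)


def _sketch_unpack_target_velocity(v_edit: torch.Tensor) -> torch.Tensor:
    # Keep only the left (target) half of the width dimension.
    return v_edit[..., : v_edit.shape[-1] // 2]

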
def main():
    print()
    print("🔨 MicroForge Architecture Test Suite")
    print("=" * 60)
    print()
    test_vae()
    test_backbone()
    test_planner()
    test_training()
    test_pipeline()
    test_editing_pathway()
    print("=" * 60)
    print("✅ ALL TESTS PASSED")
    print("=" * 60)


if __name__ == "__main__":
    main()