| import math |
|
|
| import torch |
|
|
| from src.transformer_video import WanDiscreteVideoTransformer |
|
|
|
|
| def _available_device(): |
| return "cuda" if torch.cuda.is_available() else "cpu" |
|
|
|
|
def test_wan_discrete_video_transformer_forward_and_shapes():
    """Smoke-test WanDiscreteVideoTransformer.

    Builds a tiny model, runs a single forward pass on random pseudo-video
    tokens plus random text embeddings, then checks the logits' shape, the
    parameter count, and (when CUDA is present) basic memory accounting.
    """
    device = _available_device()

    # Deliberately tiny configuration so the test stays fast on CPU.
    codebook_size = 128
    num_frames, height, width = 2, 16, 16

    model = WanDiscreteVideoTransformer(
        codebook_size=codebook_size,
        # one id beyond the codebook — presumably a special/mask token; confirm
        vocab_size=codebook_size + 1,
        num_frames=num_frames,
        height=height,
        width=width,
        in_dim=32,
        dim=64,
        ffn_dim=128,
        freq_dim=32,
        text_dim=64,
        out_dim=32,
        num_heads=4,
        num_layers=2,
    ).to(device)
    model.eval()

    batch_size = 2

    # Random token grid standing in for a quantized video clip.
    tokens = torch.randint(
        0,
        codebook_size,
        (batch_size, num_frames, height, width),
        dtype=torch.long,
        device=device,
    )

    # Random conditioning in place of real text-encoder features.
    text_seq_len = 8
    encoder_hidden_states = torch.randn(
        batch_size, text_seq_len, model.backbone.text_dim, device=device
    )

    # One diffusion timestep per batch element.
    timesteps = torch.randint(
        0, 1000, (batch_size,), dtype=torch.long, device=device
    )

    # Snapshot allocator state before the forward pass (CUDA only).
    if device == "cuda":
        torch.cuda.reset_peak_memory_stats()
        mem_before = torch.cuda.memory_allocated()
    else:
        mem_before = 0

    with torch.no_grad():
        logits = model(
            tokens=tokens,
            timesteps=timesteps,
            encoder_hidden_states=encoder_hidden_states,
            y=None,
        )

    if device == "cuda":
        mem_after = torch.cuda.memory_allocated()
        peak_mem = torch.cuda.max_memory_allocated()
    else:
        mem_after = mem_before
        peak_mem = mem_before

    # Leading dims: (batch, codebook, frames, ...).
    assert logits.shape[0] == batch_size
    assert logits.shape[1] == codebook_size
    assert logits.shape[2] == num_frames

    # Spatial dims shrink by the backbone's spatial patch sizes.
    patch_h = model.backbone.patch_size[1]
    patch_w = model.backbone.patch_size[2]
    assert logits.shape[3] == height // patch_h
    assert logits.shape[4] == width // patch_w

    # Sanity-check the parameter count.
    total_params = sum(param.numel() for param in model.parameters())
    assert total_params > 0
    assert math.isfinite(float(total_params))

    # The peak must bound the post-forward footprint, which in turn must
    # not dip below the pre-forward baseline while `logits` is held.
    if device == "cuda":
        assert peak_mem >= mem_after >= mem_before
|
|
|
|
|
|