| import unittest |
| import torch |
| import sys |
| import os |
|
|
| |
| sys.path.append(os.path.abspath(os.path.join(os.path.dirname(__file__), '..'))) |
|
|
| from src.config import ModelConfig |
| from src.models.autoencoder import LatentAutoencoder |
| from src.models.dit import FlowDiT |
|
|
class TestModels(unittest.TestCase):
    """Shape/smoke tests for the latent autoencoder and the flow-matching DiT.

    Each test builds a deliberately tiny model from a small ModelConfig and
    only checks output tensor shapes — no training, no pretrained weights
    beyond whatever the model constructors load themselves.
    """

    # Dimensions shared by every test; must match the config below.
    BATCH = 2
    SEQ_LEN = 32
    LATENT_DIM = 128
    # Vocabulary size of roberta-base — assumed to match the decoder head;
    # TODO(review): confirm against the tokenizer if the encoder changes.
    VOCAB_SIZE = 50265

    def setUp(self):
        # Small config so each test instantiates a lightweight model.
        self.cfg = ModelConfig(
            encoder_name="roberta-base",
            latent_dim=self.LATENT_DIM,
            max_seq_len=self.SEQ_LEN,
            decoder_layers=2,
            dit_layers=2,
        )

    def test_ae_shape(self):
        """Autoencoder returns (logits, latents) with the expected shapes."""
        print("\nTesting Autoencoder Shape...")
        autoencoder = LatentAutoencoder(self.cfg)
        token_ids = torch.randint(0, 100, (self.BATCH, self.SEQ_LEN))
        attn_mask = torch.ones((self.BATCH, self.SEQ_LEN))
        logits, latents = autoencoder(token_ids, attn_mask)
        # Latents live in the configured latent space...
        self.assertEqual(latents.shape, (self.BATCH, self.SEQ_LEN, self.LATENT_DIM))
        # ...and logits span the full vocabulary.
        self.assertEqual(logits.shape, (self.BATCH, self.SEQ_LEN, self.VOCAB_SIZE))
        print("AE Shape Check Passed.")

    def test_dit_shape(self):
        """Conditioned DiT forward pass preserves the latent tensor shape."""
        print("\nTesting DiT Shape...")
        dit = FlowDiT(self.cfg)
        latent_shape = (self.BATCH, self.SEQ_LEN, self.LATENT_DIM)
        noisy_latents = torch.randn(*latent_shape)
        timesteps = torch.rand(self.BATCH)
        context = torch.randn(*latent_shape)
        prediction = dit(noisy_latents, timesteps, condition=context)
        self.assertEqual(prediction.shape, latent_shape)
        print("DiT Shape Check Passed.")

    def test_cfg_forward(self):
        """Classifier-free-guidance forward also preserves the latent shape."""
        print("\nTesting CFG Forward...")
        dit = FlowDiT(self.cfg)
        latent_shape = (self.BATCH, self.SEQ_LEN, self.LATENT_DIM)
        noisy_latents = torch.randn(*latent_shape)
        timesteps = torch.rand(self.BATCH)
        context = torch.randn(*latent_shape)
        prediction = dit.forward_with_cfg(noisy_latents, timesteps, context, cfg_scale=3.0)
        self.assertEqual(prediction.shape, latent_shape)
        print("CFG Check Passed.")
|
|
if __name__ == "__main__":
    # Discover and run every TestModels case when this file is executed directly.
    unittest.main()