| """Tests for AAM Diffusion Model components.""" |
|
|
| import torch |
| import pytest |
| from diffusion_llm.config.model_config import AamDiffusionConfig, get_default_config, ModelConfig |
| from diffusion_llm.model.noise_scheduler import NoiseScheduler |
| from diffusion_llm.model.graph_encoder import GraphConditioningEncoder, GraphEncoderConfig |
| from diffusion_llm.model.diffusion_transformer import DiffusionTransformer |
| from diffusion_llm.model.aam_diffusion_model import AamDiffusionModel |
| from diffusion_llm.tokenizer.aam_tokenizer import AamTokenizer |
|
|
|
|
class TestConfig:
    """Checks for the configuration presets and their helper methods."""

    def test_default_config(self):
        """The "base" preset carries the expected model and diffusion sizes."""
        base = get_default_config("base")
        assert base.model.d_model == 768
        assert base.model.n_layers == 12
        assert base.diffusion.n_timesteps == 1000

    def test_tiny_config(self):
        """The "tiny" preset carries the expected model sizes."""
        tiny = get_default_config("tiny")
        assert tiny.model.d_model == 256
        assert tiny.model.n_layers == 4

    def test_config_serialization(self, tmp_path):
        """A config written with to_json is read back with the same model sizes."""
        original = get_default_config("small")
        json_path = tmp_path / "config.json"
        original.to_json(json_path)

        restored = AamDiffusionConfig.from_json(json_path)
        # Compare the two model fields this test cares about, in order.
        for field_name in ("d_model", "n_layers"):
            assert getattr(restored.model, field_name) == getattr(
                original.model, field_name
            )

    def test_param_estimation(self):
        """estimate_params() yields something containing "M"."""
        # Presumably a human-readable size string with a millions suffix;
        # only the presence of "M" is checked here.
        estimate = ModelConfig(d_model=768, n_layers=12, d_ff=3072).estimate_params()
        assert "M" in estimate
|
|
|
|
class TestTokenizer:
    """Checks for AamTokenizer training, encoding, decoding and persistence."""

    def test_basic_encoding(self):
        """Encoding yields a non-empty list framed by the BOS and EOS ids."""
        tok = AamTokenizer()
        corpus = ["Hello world this is a test", "Another test sentence"]
        tok.train(corpus)

        encoded = tok.encode("Hello world")
        assert isinstance(encoded, list)
        assert len(encoded) > 0
        first, last = encoded[0], encoded[-1]
        assert first == tok.bos_id
        assert last == tok.eos_id

    def test_decode_roundtrip(self):
        """Each training text encodes and decodes to a non-empty result."""
        tok = AamTokenizer()
        samples = [
            "Berdasarkan analisis, pencuri adalah Diancang.",
            "Anomali terdeteksi dalam laporan Hefei.",
            "Evidence: Ju Jangmok, Snow Plum Pill.",
        ]
        tok.train(samples)

        for sample in samples:
            restored = tok.decode(tok.encode(sample), skip_special=True)
            # Only non-emptiness is asserted, not equality with the input.
            assert len(restored) > 0

    def test_special_tokens(self):
        """A fresh tokenizer uses ids 0, 1 and 2 for pad, BOS and EOS."""
        tok = AamTokenizer()
        assert tok.pad_id == 0
        assert tok.bos_id == 1
        assert tok.eos_id == 2

    def test_sentence_boundaries(self):
        """Positions 3 and 6 of the sample sequence are reported as boundaries."""
        tok = AamTokenizer()
        sequence = [1, 10, 20, 5, 30, 40, 5, 50, 2]
        found = tok.get_sentence_boundaries(sequence)
        # Both occurrences of id 5 sit at these indices.
        for position in (3, 6):
            assert position in found

    def test_save_load(self, tmp_path):
        """A saved tokenizer reloads with the same vocab size and trained flag."""
        source = AamTokenizer()
        source.train(["Test text for tokenizer", "Another training example"])

        target_file = tmp_path / "tokenizer.json"
        source.save(target_file)

        restored = AamTokenizer.load(target_file)
        assert restored.vocab_size == source.vocab_size
        assert restored.is_trained

    def test_structure_encoding(self):
        """encode_with_structure returns a non-empty list of ids."""
        tok = AamTokenizer()
        tok.train(["Evidence text", "Anomaly description", "Reasoning step"])

        structured = tok.encode_with_structure(
            text="Main narrative text",
            evidence_nodes=["evidence1", "evidence2"],
            anomalies=["anomaly1"],
        )
        assert isinstance(structured, list)
        assert len(structured) > 0

    def test_padding(self):
        """pad_sequence extends a 3-id sequence to max_len with zeros."""
        tok = AamTokenizer()
        result = tok.pad_sequence([1, 2, 3], max_len=10)
        assert len(result) == 10
        tail = result[3:]
        assert tail == [0] * 7
|
|
|
|
class TestDiffusionTransformer:
    """Shape checks for the DiffusionTransformer forward pass."""

    def test_forward_pass(self):
        """Without graph inputs, the output has the same shape as x_t."""
        model_cfg = ModelConfig(
            d_model=128,
            n_layers=2,
            n_heads=4,
            d_ff=256,
            vocab_size=1000,
            max_seq_len=64,
        )
        transformer = DiffusionTransformer(model_cfg)

        expected_shape = (2, 32, 128)
        noisy = torch.randn(*expected_shape)
        steps = torch.tensor([100, 500])

        result = transformer(x_t=noisy, t=steps)
        assert result.shape == expected_shape

    def test_with_graph_conditioning(self):
        """With graph keys and values supplied, the output shape is unchanged."""
        model_cfg = ModelConfig(
            d_model=128,
            n_layers=2,
            n_heads=4,
            d_ff=256,
            vocab_size=1000,
            max_seq_len=64,
        )
        transformer = DiffusionTransformer(model_cfg)

        expected_shape = (2, 32, 128)
        noisy = torch.randn(*expected_shape)
        steps = torch.tensor([100, 500])
        # Keys are drawn before values, keeping the random draw order.
        keys = torch.randn(2, 10, 128)
        values = torch.randn(2, 10, 128)

        result = transformer(
            x_t=noisy,
            t=steps,
            graph_keys=keys,
            graph_values=values,
        )
        assert result.shape == expected_shape
|
|
|
|
class TestAamDiffusionModel:
    """End-to-end checks for the complete AamDiffusionModel."""

    def test_model_creation_tiny(self):
        """The "tiny" preset builds a model with between 0 and 100M parameters."""
        config = get_default_config("tiny")
        model = AamDiffusionModel(config)
        n_params = model.get_num_params()
        assert n_params > 0
        assert n_params < 100e6

    def test_forward_training(self):
        """The forward call returns a (predicted, noise) pair of equal shape."""
        config = get_default_config("tiny")
        model = AamDiffusionModel(config)
        # The model is run in eval mode without gradients; only shapes matter.
        model.eval()

        token_ids = torch.randint(0, config.model.vocab_size, (2, 32))
        timestep = torch.randint(0, config.diffusion.n_timesteps, (2,))

        with torch.no_grad():
            predicted, noise = model(token_ids=token_ids, timestep=timestep)

        assert predicted.shape == noise.shape

    def test_loss_computation(self):
        """compute_loss returns a loss that is not NaN and is non-negative."""
        config = get_default_config("tiny")
        model = AamDiffusionModel(config)
        model.eval()

        token_ids = torch.randint(0, config.model.vocab_size, (2, 32))
        timestep = torch.randint(0, config.diffusion.n_timesteps, (2,))

        with torch.no_grad():
            predicted, noise = model(token_ids=token_ids, timestep=timestep)
            loss = model.compute_loss(predicted, noise, timestep)

        # Check NaN first: NaN >= 0 is False, so the sign check alone would
        # fail on a NaN loss without pointing at the real cause.
        assert not torch.isnan(loss)
        assert loss.item() >= 0

    def test_save_load(self, tmp_path):
        """A saved model reloads with the same d_model in its config."""
        config = get_default_config("tiny")
        model = AamDiffusionModel(config)

        # save() is given a plain string path rather than a Path object.
        path = str(tmp_path / "model.pt")
        model.save(path)

        loaded = AamDiffusionModel.load(path)
        assert loaded.config.model.d_model == config.model.d_model
|
|
|
|
class TestGraphEncoder:
    """Checks for the GraphConditioningEncoder output structure."""

    def test_evidence_encoding(self):
        """Encoding evidence ids with confidences yields "keys" and "values"."""
        encoder_cfg = GraphEncoderConfig(d_graph=128, n_graph_layers=2, n_graph_heads=4)
        graph_encoder = GraphConditioningEncoder(encoder_cfg, vocab_size=1000)

        ids = torch.randint(0, 1000, (2, 5, 16))
        confidence = torch.tensor(
            [
                [0.8, 0.6, 0.9, 0.7, 0.5],
                [0.7, 0.8, 0.6, 0.9, 0.5],
            ]
        )

        encoded = graph_encoder(evidence_ids=ids, evidence_confidence=confidence)
        for expected_key in ("keys", "values"):
            assert expected_key in encoded

    def test_no_input(self):
        """Calling the encoder with no arguments still yields "keys"."""
        encoder_cfg = GraphEncoderConfig(d_graph=128, n_graph_layers=2, n_graph_heads=4)
        graph_encoder = GraphConditioningEncoder(encoder_cfg, vocab_size=1000)

        # NOTE(review): the output is presumably all zeros in this case, but
        # only the presence of "keys" is asserted here.
        encoded = graph_encoder()
        assert "keys" in encoded
|
|
|
|
if __name__ == "__main__":
    # Run this file's tests verbosely and pass pytest's exit code on to the
    # process, so a failing run does not exit with status 0.
    raise SystemExit(pytest.main([__file__, "-v"]))
|
|