|
|
| """Test script for Zenith-7B model"""
|
|
|
| import torch
|
| import unittest
|
| from pathlib import Path
|
| import sys
|
|
|
| sys.path.append(str(Path(__file__).parent))
|
|
|
| from configs.zenith_config import get_7b_config
|
| from models.zenith_model import ZenithForCausalLM, ZenithModel
|
| from data.advanced_tokenizer import AdvancedTokenizer
|
|
|
|
|
| class TestZenith7B(unittest.TestCase):
|
| """Test suite for Zenith-7B model."""
|
|
|
| @classmethod
|
| def setUpClass(cls):
|
| """Set up test fixtures."""
|
| cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
|
| cls.config = get_7b_config()
|
| cls.config.vocab_size = 32000
|
|
|
|
|
| cls.model = ZenithModel(cls.config)
|
| cls.model.to(cls.device)
|
| cls.model.eval()
|
|
|
|
|
| cls.tokenizer = AdvancedTokenizer(vocab_size=32000)
|
|
|
| def test_model_creation(self):
|
| """Test model can be created."""
|
| self.assertIsNotNone(self.model)
|
| self.assertTrue(hasattr(self.model, 'transformer'))
|
|
|
| def test_forward_pass(self):
|
| """Test forward pass works."""
|
| batch_size = 2
|
| seq_len = 32
|
| input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
|
| attention_mask = torch.ones(batch_size, seq_len).to(self.device)
|
|
|
| with torch.no_grad():
|
| outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
|
| self.assertIsNotNone(outputs.logits)
|
| self.assertEqual(outputs.logits.shape[0], batch_size)
|
| self.assertEqual(outputs.logits.shape[1], seq_len)
|
| self.assertEqual(outputs.logits.shape[2], self.config.vocab_size)
|
|
|
| def test_moe_activation(self):
|
| """Test MoE layers are active when configured."""
|
| if self.config.num_experts > 1:
|
|
|
| moe_layers = [m for m in self.model.modules() if hasattr(m, 'num_experts')]
|
| self.assertGreater(len(moe_layers), 0)
|
|
|
| def test_generation(self):
|
| """Test text generation."""
|
| prompt = "Hello, world!"
|
| input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
|
|
|
| with torch.no_grad():
|
| outputs = self.model.generate(
|
| input_ids,
|
| max_new_tokens=20,
|
| temperature=0.8,
|
| do_sample=True
|
| )
|
|
|
| generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
|
| self.assertIsInstance(generated, str)
|
| self.assertGreater(len(generated), len(prompt))
|
|
|
| def test_loss_computation(self):
|
| """Test loss computation with labels."""
|
| batch_size = 2
|
| seq_len = 32
|
| input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
|
| labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
|
| attention_mask = torch.ones(batch_size, seq_len).to(self.device)
|
|
|
| outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
|
| self.assertIsNotNone(outputs.loss)
|
| self.assertTrue(torch.isfinite(outputs.loss))
|
|
|
| def test_multi_task_outputs(self):
|
| """Test multi-task learning outputs when EQ adapter is enabled."""
|
| if self.config.use_eq_adapter:
|
| batch_size = 2
|
| seq_len = 32
|
| input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
|
| attention_mask = torch.ones(batch_size, seq_len).to(self.device)
|
|
|
| outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
|
|
|
|
|
| self.assertTrue(hasattr(outputs, 'emotion_logits') or outputs.emotion_logits is not None)
|
| self.assertTrue(hasattr(outputs, 'frustration_logits') or outputs.frustration_logits is not None)
|
|
|
| def test_gradient_flow(self):
|
| """Test gradients flow correctly."""
|
| batch_size = 1
|
| seq_len = 16
|
| input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
|
| labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
|
|
|
| self.model.train()
|
| outputs = self.model(input_ids=input_ids, labels=labels)
|
| loss = outputs.loss
|
| loss.backward()
|
|
|
|
|
| has_grad = any(p.grad is not None for p in self.model.parameters() if p.requires_grad)
|
| self.assertTrue(has_grad)
|
|
|
|
|
def run_tests():
    """Run the TestZenith7B suite verbosely and print a summary.

    Returns:
        bool: True when the run finished with no failures or errors.
    """
    banner = "=" * 60

    # Header
    print(banner)
    print("Zenith-7B Model Test Suite")
    print(banner)

    # Collect every test method of the case and execute it verbosely.
    suite = unittest.TestLoader().loadTestsFromTestCase(TestZenith7B)
    result = unittest.TextTestRunner(verbosity=2).run(suite)

    # Summary
    print("\n" + banner)
    print("Test Summary:")
    print(f" Tests run: {result.testsRun}")
    print(f" Failures: {len(result.failures)}")
    print(f" Errors: {len(result.errors)}")
    print(f" Success: {result.wasSuccessful()}")
    print(banner)

    return result.wasSuccessful()
|
|
|
|
|
if __name__ == "__main__":
    # Turn the suite outcome into a process exit status: 0 on success, 1 otherwise.
    if run_tests():
        sys.exit(0)
    sys.exit(1)