# Source: Zenith-7b-V1 / test_model.py
# Uploaded by Zandy-Wandy - "Upload Zenith-7B model" (commit 8d18b7c, verified)
#!/usr/bin/env python3
"""Test script for Zenith-7B model"""
import torch
import unittest
from pathlib import Path
import sys
sys.path.append(str(Path(__file__).parent))
from configs.zenith_config import get_7b_config
from models.zenith_model import ZenithForCausalLM, ZenithModel
from data.advanced_tokenizer import AdvancedTokenizer
class TestZenith7B(unittest.TestCase):
"""Test suite for Zenith-7B model."""
@classmethod
def setUpClass(cls):
"""Set up test fixtures."""
cls.device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
cls.config = get_7b_config()
cls.config.vocab_size = 32000 # Test vocab size
# Create small test model
cls.model = ZenithModel(cls.config)
cls.model.to(cls.device)
cls.model.eval()
# Create tokenizer
cls.tokenizer = AdvancedTokenizer(vocab_size=32000)
def test_model_creation(self):
"""Test model can be created."""
self.assertIsNotNone(self.model)
self.assertTrue(hasattr(self.model, 'transformer'))
def test_forward_pass(self):
"""Test forward pass works."""
batch_size = 2
seq_len = 32
input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
attention_mask = torch.ones(batch_size, seq_len).to(self.device)
with torch.no_grad():
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
self.assertIsNotNone(outputs.logits)
self.assertEqual(outputs.logits.shape[0], batch_size)
self.assertEqual(outputs.logits.shape[1], seq_len)
self.assertEqual(outputs.logits.shape[2], self.config.vocab_size)
def test_moe_activation(self):
"""Test MoE layers are active when configured."""
if self.config.num_experts > 1:
# Check that MoE layers exist
moe_layers = [m for m in self.model.modules() if hasattr(m, 'num_experts')]
self.assertGreater(len(moe_layers), 0)
def test_generation(self):
"""Test text generation."""
prompt = "Hello, world!"
input_ids = self.tokenizer.encode(prompt, return_tensors="pt").to(self.device)
with torch.no_grad():
outputs = self.model.generate(
input_ids,
max_new_tokens=20,
temperature=0.8,
do_sample=True
)
generated = self.tokenizer.decode(outputs[0], skip_special_tokens=True)
self.assertIsInstance(generated, str)
self.assertGreater(len(generated), len(prompt))
def test_loss_computation(self):
"""Test loss computation with labels."""
batch_size = 2
seq_len = 32
input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
attention_mask = torch.ones(batch_size, seq_len).to(self.device)
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask, labels=labels)
self.assertIsNotNone(outputs.loss)
self.assertTrue(torch.isfinite(outputs.loss))
def test_multi_task_outputs(self):
"""Test multi-task learning outputs when EQ adapter is enabled."""
if self.config.use_eq_adapter:
batch_size = 2
seq_len = 32
input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
attention_mask = torch.ones(batch_size, seq_len).to(self.device)
outputs = self.model(input_ids=input_ids, attention_mask=attention_mask)
# Check for emotion and frustration logits if EQ adapter is enabled
self.assertTrue(hasattr(outputs, 'emotion_logits') or outputs.emotion_logits is not None)
self.assertTrue(hasattr(outputs, 'frustration_logits') or outputs.frustration_logits is not None)
def test_gradient_flow(self):
"""Test gradients flow correctly."""
batch_size = 1
seq_len = 16
input_ids = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
labels = torch.randint(0, self.config.vocab_size, (batch_size, seq_len)).to(self.device)
self.model.train()
outputs = self.model(input_ids=input_ids, labels=labels)
loss = outputs.loss
loss.backward()
# Check that gradients exist
has_grad = any(p.grad is not None for p in self.model.parameters() if p.requires_grad)
self.assertTrue(has_grad)
def run_tests():
    """Run the TestZenith7B suite verbosely and print a short summary.

    Returns:
        True if every test passed, False otherwise.
    """
    banner = "=" * 60
    print(banner)
    print("Zenith-7B Model Test Suite")
    print(banner)

    suite = unittest.TestLoader().loadTestsFromTestCase(TestZenith7B)
    result = unittest.TextTestRunner(verbosity=2).run(suite)
    passed = result.wasSuccessful()

    summary_lines = [
        "\n" + banner,
        "Test Summary:",
        f" Tests run: {result.testsRun}",
        f" Failures: {len(result.failures)}",
        f" Errors: {len(result.errors)}",
        f" Success: {passed}",
        banner,
    ]
    for line in summary_lines:
        print(line)
    return passed
if __name__ == "__main__":
    # Map the suite result to a process exit status: 0 on success, 1 otherwise.
    sys.exit(int(not run_tests()))