| | import torch
|
| | import soundfile as sf
|
| | import os
|
| | from .model import HexaTransformer
|
| | from .text_encoder import TextEncoder
|
| | from .config import HexaConfig
|
| |
|
def run_tiny_test():
    """
    Smoke-test the Hexa architecture end-to-end with a tiny config.

    Builds a small HexaTransformer on CPU, runs one forward pass over an
    encoded English sentence with zero-index speaker/language/emotion
    conditioning, and writes a placeholder wav (random noise sized from the
    model output) so the I/O path is exercised as well.
    """
    print("Initializing Tiny Hexa Model for Code Verification...")

    # Deliberately small hyperparameters so the test fits in memory.
    cfg = HexaConfig(
        dim=512,
        depth=6,
        heads=8,
        dim_head=64,
        num_languages=15,
    )

    device = "cpu"
    net = HexaTransformer(cfg)
    net.to(device)
    net.eval()

    # Report total parameter count in millions.
    n_params = sum(w.numel() for w in net.parameters())
    print(f"Tiny Model Size: {n_params / 1e6:.2f} Million parameters")

    # Encode a short English sentence into token ids.
    sample_text = "Hello world, testing tiny hexa."
    tokenizer = TextEncoder()
    token_ids = tokenizer.preprocess(sample_text, lang_code='en').to(device)
    print(f"Encoded text shape: {token_ids.shape}")

    # Zero-index conditioning tensors: speaker, language, emotion.
    spk_id = torch.tensor([0]).to(device)
    lang_id = torch.tensor([0]).to(device)
    emo_id = torch.tensor([0]).to(device)

    # Inference only — no autograd graph needed.
    with torch.no_grad():
        output = net(token_ids, spk_id, lang_id, emo_id)

    print(f"Forward pass successful. Output shape: {output.shape}")

    # Placeholder audio: random noise sized from the model output
    # (256 samples per output frame — presumably the hop size; no vocoder yet).
    noise = torch.randn(output.shape[1] * 256).numpy()
    sf.write("tiny_output.wav", noise, cfg.sample_rate)
    print("Saved tiny_output.wav")
|
| |
|
# Script entry point: run the tiny-model smoke test when executed directly.
if __name__ == "__main__":
    run_tiny_test()
|
| |
|