# My-TTS-Streamlit / model/inference.py
import torch
import torchaudio
from model.network import MiniTTS
from model.dataset import TextProcessor # We reuse the text logic we already wrote!
class TTSInference:
    """Text-to-speech inference wrapper.

    Runs the trained MiniTTS model autoregressively to produce a mel
    spectrogram, then converts it to a waveform with an inverse-mel
    projection followed by Griffin-Lim phase reconstruction.
    """

    # Audio/model constants — must match the values used during training.
    SAMPLE_RATE = 22050
    N_MELS = 80
    N_FFT = 1024
    N_CHARS = 40

    def __init__(self, checkpoint_path, device='cpu'):
        """Load the model and build the vocoder transforms.

        Args:
            checkpoint_path: Path to the saved state_dict checkpoint.
            device: Torch device string ('cpu' or 'cuda').
        """
        self.device = device
        self.model = self.load_model(checkpoint_path)
        # The vocoder transforms are stateless and deterministic, so build
        # them once here instead of on every predict() call.
        self.inverse_mel_scaler = torchaudio.transforms.InverseMelScale(
            n_stft=self.N_FFT // 2 + 1,  # 513 linear-frequency bins for n_fft=1024
            n_mels=self.N_MELS,
            sample_rate=self.SAMPLE_RATE,
        ).to(self.device)
        self.griffin_lim = torchaudio.transforms.GriffinLim(
            n_fft=self.N_FFT, n_iter=32
        ).to(self.device)
        print(f"Model loaded from {checkpoint_path}")

    def load_model(self, path):
        """Instantiate the training architecture and load checkpoint weights.

        Returns the model in eval mode on self.device.
        """
        # 1. Initialize the same architecture as training.
        model = MiniTTS(num_chars=self.N_CHARS, num_mels=self.N_MELS)
        # 2. Load the weights.
        # map_location ensures it loads on CPU even if trained on GPU.
        # weights_only=True: the checkpoint is a plain state_dict, so avoid
        # arbitrary-pickle execution when loading untrusted files.
        state_dict = torch.load(path, map_location=self.device, weights_only=True)
        model.load_state_dict(state_dict)
        return model.eval().to(self.device)

    def predict(self, text, num_frames=150):
        """Synthesize audio for `text`.

        Args:
            text: Input string to synthesize.
            num_frames: Number of mel frames to generate autoregressively
                (150 frames is roughly 1.5 s of audio at these settings;
                increase for longer sentences).

        Returns:
            Tuple of (audio ndarray, sample rate, mel spectrogram ndarray).
        """
        # 1. Text preprocessing -> [1, text_len] index tensor.
        text_tensor = TextProcessor.text_to_sequence(text).unsqueeze(0).to(self.device)

        # 2. Autoregressive inference: start from ONE silent frame; the model
        # predicts the next frame, which we append and feed back in.
        with torch.no_grad():
            # [Batch=1, Time=1, Mels] of zeros as the "go" frame.
            decoder_input = torch.zeros(1, 1, self.N_MELS).to(self.device)
            for _ in range(num_frames):
                prediction = self.model(text_tensor, decoder_input)
                # Keep only the newest predicted frame (the last time step).
                new_frame = prediction[:, -1:, :]
                decoder_input = torch.cat([decoder_input, new_frame], dim=1)

            # Generated spectrogram, including the initial silent frame.
            # Shape: [1, num_frames + 1, 80] -> [1, 80, num_frames + 1].
            mel_spec = decoder_input.transpose(1, 2)

            # 3. Vocoder (spectrogram -> audio): inverse mel scale, then
            # Griffin-Lim phase reconstruction (transforms built in __init__).
            linear_spec = self.inverse_mel_scaler(mel_spec)
            audio = self.griffin_lim(linear_spec)

        return audio.squeeze(0).cpu().numpy(), self.SAMPLE_RATE, mel_spec.squeeze(0).cpu().numpy()