Hugging Face Spaces page capture — Space status: Sleeping.
# app.py
# Streamlit front-end for a custom Transformer-based text-to-speech model.
import streamlit as st
import numpy as np
import matplotlib.pyplot as plt
import os
from model.inference import TTSInference

# Page Config
st.set_page_config(page_title="My Custom TTS Engine", layout="wide")
st.title("🎙️ Custom Architecture TTS Playground")
st.markdown("This project demonstrates a custom PyTorch implementation of a Transformer-based TTS.")

# Sidebar for Model Controls
with st.sidebar:
    st.header("Model Settings")
    # Checkpoint paths are relative to the app root; the .pth files must be
    # uploaded to the 'checkpoints' folder of the Space for loading to succeed.
    checkpoint = st.selectbox("Select Checkpoint", [
        "checkpoints/checkpoint_epoch_50c.pth",
        "checkpoints/checkpoint_epoch_3c.pth",
        "checkpoints/checkpoint_epoch_8.pth"
    ])
    # Force CPU for Hugging Face free tier to prevent CUDA errors
    device = st.radio("Device", ["cpu"])
    st.info("Load a specific training checkpoint to compare progress.")
# --- CRITICAL FIX FOR CLOUD: Cache the model ---
# Streamlit reruns this entire script on every widget interaction, so without
# caching the checkpoint would be re-loaded from disk each time. st.cache_resource
# keeps one engine instance alive per (ckpt_path, dev) pair across reruns.
@st.cache_resource
def load_engine(ckpt_path, dev):
    """Load a TTSInference engine from *ckpt_path* on device *dev*.

    Returns None when the checkpoint file does not exist (e.g. it has not
    been uploaded to the Space yet) so the caller can show a friendly error
    instead of crashing.
    """
    if not os.path.exists(ckpt_path):
        return None  # Return None if file isn't uploaded yet
    return TTSInference(checkpoint_path=ckpt_path, device=dev)
# Initialize the Inference Engine (None when the selected checkpoint is missing)
tts_engine = load_engine(checkpoint, device)

# Main Input Area
text_input = st.text_area("Enter Text to Speak:", "Deep learning is fascinating.", height=100)
# Narrow column for generation controls/output, wide column for architecture notes.
col1, col2 = st.columns([1, 2])
with col1:
    if st.button("Generate Audio", type="primary"):
        if tts_engine is None:
            # Engine failed to load because the checkpoint file is absent.
            st.error(f"⚠️ Error: Could not find '{checkpoint}'. Did you upload it to the 'checkpoints' folder on Hugging Face?")
        else:
            with st.spinner("Running Inference..."):
                # Call your backend
                audio_data, sample_rate, mel_spec = tts_engine.predict(text_input)

            # Play Audio
            st.success("Generation Complete!")
            st.audio(audio_data, sample_rate=sample_rate)

            # --- VISUALIZATION ---
            st.subheader("Mel Spectrogram Analysis")
            fig, ax = plt.subplots(figsize=(10, 3))
            im = ax.imshow(mel_spec, aspect='auto', origin='lower', cmap='inferno')
            # Use the explicit Figure/Axes API rather than the plt.* global
            # state machine: with concurrent Streamlit sessions the "current
            # figure" global is not guaranteed to be this one.
            fig.colorbar(im, ax=ax)
            ax.set_title("Generated Mel Spectrogram")
            ax.set_xlabel("Time Frames")
            ax.set_ylabel("Mel Channels")
            st.pyplot(fig)
            # Close the figure explicitly: the Streamlit server process stays
            # alive across reruns, so un-closed figures accumulate and leak memory.
            plt.close(fig)
with col2:
    st.subheader("Architecture Details")
    # Static reference snippet shown to visitors; it is displayed only and
    # never executed by this app.
    architecture_snippet = """
class TextToMel(nn.Module):
    def __init__(self):
        super().__init__()
        self.encoder = TransformerEncoder(...)
        self.decoder = TransformerDecoder(...)
    def forward(self, text):
        # 1. Embed text
        # 2. Add Positional Encodings
        # 3. Predict Mel Frames
        return mel_output
"""
    st.code(architecture_snippet, language="python")