# OuteTTS-DEMO / app.py
# Hugging Face Space by drewThomasson (commit a2dc963, "Update app.py").
import gradio as gr
from outetts.v0_1.interface import InterfaceHF
import logging
import os
import tempfile
# Import faster-whisper for transcription
from faster_whisper import WhisperModel
# Module logger for this app, plus a root-logger setup (applied only if the root
# logger has no handlers yet) that prints INFO-and-above records to stderr.
logger = logging.getLogger(name=__name__)
logging.basicConfig(level="INFO")
# Initialize the OuteTTS interface with the Hugging Face model
try:
logger.info("Initializing OuteTTS InterfaceHF with model 'OuteAI/OuteTTS-0.1-350M'")
interface = InterfaceHF("OuteAI/OuteTTS-0.1-350M")
logger.info("Model loaded successfully.")
except Exception as e:
logger.error(f"Failed to load model: {e}")
raise e
# Initialize the faster-whisper model used to transcribe voice-cloning references.
WHISPER_MODEL_SIZE = "tiny"
WHISPER_DEVICE = "cpu"
WHISPER_COMPUTE_TYPE = "int8"
try:
    logger.info("Initializing faster-whisper model for transcription.")
    whisper_model = WhisperModel(
        WHISPER_MODEL_SIZE,
        device=WHISPER_DEVICE,
        compute_type=WHISPER_COMPUTE_TYPE,
    )
    logger.info("faster-whisper model loaded successfully.")
except Exception as e:
    # Voice cloning depends on this model; log the traceback and abort startup
    # with the original exception (bare raise keeps it unchanged).
    logger.exception("Failed to load faster-whisper model: %s", e)
    raise
def _save_audio_to_temp_wav(output):
    """Save a generated-audio object via its ``save(path)`` method to a new, uniquely named temp WAV file.

    Returns the file's path. If saving fails, the empty placeholder file is
    removed and the exception is re-raised.
    """
    fd, path = tempfile.mkstemp(prefix="outetts_", suffix=".wav")
    os.close(fd)  # only the unique name is needed; output.save() reopens the path
    try:
        output.save(path)
    except Exception:
        os.remove(path)
        raise
    return path


def generate_tts(text, temperature, repetition_penalty, max_length, speaker=None):
    """
    Generate speech from the input text using the OuteTTS model.

    Parameters:
        text (str): The input text for TTS. Blank or whitespace-only text is rejected.
        temperature (float): Sampling temperature.
        repetition_penalty (float): Repetition penalty.
        max_length (int | float): Maximum length of the generated audio tokens.
            Converted to ``int`` because UI sliders may deliver a float.
        speaker (dict | None): Speaker configuration for voice cloning, e.g. as
            produced by ``create_speaker_with_transcription``. ``None`` or an
            empty value means no cloned voice.

    Returns:
        str | None: Path to a newly created WAV file, or ``None`` if the text was
        blank or generation/saving failed (the error is logged with traceback).
    """
    # An empty JSON value coming from the UI means "no speaker".
    speaker = speaker or None
    logger.info("Received TTS generation request.")
    logger.info(
        f"Parameters - Text: {text}, Temperature: {temperature}, "
        f"Repetition Penalty: {repetition_penalty}, Max Length: {max_length}, "
        f"Speaker: {speaker is not None}"
    )
    if not text or not text.strip():
        logger.warning("Empty text received; skipping TTS generation.")
        return None
    try:
        # outetts v0.1's InterfaceHF.generate names this keyword 'max_lenght'
        # (sic), so it must be passed with that spelling.
        output = interface.generate(
            text=text,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            max_lenght=int(max_length),
            speaker=speaker,
        )
        logger.info("TTS generation complete.")
        # A unique file per request: a single fixed name in the shared temp dir
        # let concurrent requests overwrite each other's audio before playback.
        # NOTE: these files are not deleted here, since the caller serves them.
        output_path = _save_audio_to_temp_wav(output)
        logger.info(f"Audio saved to {output_path}")
        return output_path  # Gradio handles the audio playback
    except Exception as e:
        logger.exception("Error during TTS generation: %s", e)
        return None
def transcribe_audio(audio_path):
    """
    Transcribe the given audio file using the module-level faster-whisper model.

    Parameters:
        audio_path (str): Path to the audio file.

    Returns:
        str: The non-empty segment texts joined by single spaces, with no
        leading/trailing whitespace ("" if nothing was recognized).
    """
    logger.info(f"Transcribing audio file: {audio_path}")
    # faster-whisper returns `segments` as a lazy generator; decoding actually
    # runs while it is consumed below. The second value (info) is unused here.
    segments, _info = whisper_model.transcribe(audio_path)
    # Segment texts usually start with a space, so joining them raw produced
    # doubled and leading spaces. Strip each piece and skip empty ones.
    pieces = (segment.text.strip() for segment in segments)
    transcript = " ".join(piece for piece in pieces if piece)
    logger.info(f"Transcription complete: {transcript}")
    return transcript
def _copy_upload_to_temp_wav(audio_file):
    """Write the bytes of a binary file-like upload to a new ``.wav``-suffixed temp file; return its path.

    The temp file is removed again if reading or writing fails.
    """
    temp_audio = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
    try:
        with temp_audio:
            temp_audio.write(audio_file.read())
    except Exception:
        os.remove(temp_audio.name)
        raise
    return temp_audio.name


def create_speaker_with_transcription(audio_file):
    """
    Create a custom speaker from a reference audio file by automatically transcribing it.

    Parameters:
        audio_file: The reference audio, either
            * a path (``str`` / ``os.PathLike``) to an audio file, used in place, or
            * a binary file-like object with ``.read()`` (e.g. an upload), whose
              bytes are first copied to a temporary ``.wav``-suffixed file.

    Returns:
        dict | None: Speaker configuration from ``interface.create_speaker``, or
        ``None`` if no audio was given, the transcript was empty, or an error
        occurred (errors are logged with traceback).
    """
    logger.info("Received Voice Cloning request with audio file.")
    if audio_file is None:
        logger.error("No reference audio provided.")
        return None

    temp_audio_path = None
    try:
        if isinstance(audio_file, (str, os.PathLike)):
            audio_path = os.fspath(audio_file)
        else:
            temp_audio_path = _copy_upload_to_temp_wav(audio_file)
            audio_path = temp_audio_path
            logger.info(f"Reference audio saved to {temp_audio_path}")

        transcript = transcribe_audio(audio_path)
        if not transcript.strip():
            logger.error("Transcription resulted in empty text.")
            return None

        # Create speaker using the transcribed text.
        speaker = interface.create_speaker(audio_path, transcript)
        logger.info("Speaker created successfully.")
        return speaker
    except Exception as e:
        logger.exception("Error during speaker creation: %s", e)
        return None
    finally:
        # Previously the temp copy was only removed on success; clean it up on
        # every path (success, empty transcript, error). A caller-provided path
        # is never deleted.
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
                logger.info(f"Temporary audio file {temp_audio_path} removed.")
            except OSError as cleanup_error:
                logger.warning(
                    "Could not remove temporary audio file %s: %s",
                    temp_audio_path,
                    cleanup_error,
                )
def _generate_basic_tts(text, temperature, repetition_penalty, max_length):
    """Basic TTS tab handler: call ``generate_tts`` with no speaker (no voice cloning).

    Gradio's ``inputs`` list accepts only components, so the fixed ``None``
    speaker is supplied here instead of in the event wiring.
    """
    return generate_tts(text, temperature, repetition_penalty, max_length, None)


def _create_speaker_from_upload(audio_path):
    """Voice Cloning tab handler: open the uploaded reference file and build a speaker from it.

    ``audio_path`` is the file path Gradio passes for ``type="filepath"`` audio,
    or ``None`` when nothing was uploaded. Returns the speaker configuration or
    ``None`` on failure.
    """
    if audio_path is None:
        logger.error("No reference audio uploaded.")
        return None
    try:
        with open(audio_path, "rb") as audio_file:
            return create_speaker_with_transcription(audio_file)
    except OSError as e:
        logger.exception("Could not read uploaded reference audio: %s", e)
        return None


# Define the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# 🎤 OuteTTS - Text to Speech Interface")
    gr.Markdown(
        """
        Generate speech from text using the **OuteTTS-0.1-350M** model.

        **Key Features:**

        - Pure language modeling approach to TTS
        - Voice cloning capabilities with automatic transcription
        - Compatible with LLaMa architecture
        """
    )

    with gr.Tab("Basic TTS"):
        with gr.Row():
            text_input = gr.Textbox(
                label="📄 Text Input",
                placeholder="Enter the text for TTS generation",
                lines=3,
            )
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.1,
                step=0.01,
                label="🌡️ Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.1,
                step=0.1,
                label="🔁 Repetition Penalty",
            )
            max_length = gr.Slider(
                minimum=256,
                maximum=4096,
                value=1024,
                step=256,
                label="📏 Max Length",
            )
        generate_button = gr.Button("🔊 Generate Speech")
        output_audio = gr.Audio(
            label="🎧 Generated Speech",
            type="filepath",  # the handler returns a path to a WAV file
        )
        # Basic TTS: no speaker. `None` is not a valid entry in `inputs`, so the
        # wrapper supplies it.
        generate_button.click(
            fn=_generate_basic_tts,
            inputs=[text_input, temperature, repetition_penalty, max_length],
            outputs=output_audio,
        )

    with gr.Tab("Voice Cloning"):
        with gr.Row():
            # Gradio 4+ API: `type="file"`, `source=` and `optional=` no longer
            # exist; uploads arrive as a file path.
            # NOTE(review): confirm against the Space's pinned gradio version.
            reference_audio = gr.Audio(
                label="🔊 Reference Audio",
                type="filepath",
                sources=["upload"],
            )
        create_speaker_button = gr.Button("🎤 Create Speaker")
        # Shows the created speaker and feeds it back into generation below.
        speaker_info = gr.JSON(label="🗂️ Speaker Configuration")
        with gr.Row():
            generate_cloned_speech = gr.Textbox(
                label="📄 Text Input",
                placeholder="Enter the text for TTS generation with cloned voice",
                lines=3,
            )
        with gr.Row():
            temperature_clone = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.1,
                step=0.01,
                label="🌡️ Temperature",
            )
            repetition_penalty_clone = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.1,
                step=0.1,
                label="🔁 Repetition Penalty",
            )
            max_length_clone = gr.Slider(
                minimum=256,
                maximum=4096,
                value=1024,
                step=256,
                label="📏 Max Length",
            )
        generate_cloned_button = gr.Button("🔊 Generate Cloned Speech")
        output_cloned_audio = gr.Audio(
            label="🎧 Generated Cloned Speech",
            type="filepath",  # the handler returns a path to a WAV file
        )
        # Create a speaker from the uploaded reference clip.
        create_speaker_button.click(
            fn=_create_speaker_from_upload,
            inputs=[reference_audio],
            outputs=speaker_info,
        )
        # Generate speech with the cloned voice held in `speaker_info`.
        generate_cloned_button.click(
            fn=generate_tts,
            inputs=[
                generate_cloned_speech,
                temperature_clone,
                repetition_penalty_clone,
                max_length_clone,
                speaker_info,
            ],
            outputs=output_cloned_audio,
        )

    gr.Markdown(
        """
        ---

        **Technical Blog:** [OuteTTS-0.1-350M](https://www.outeai.com/blog/OuteTTS-0.1-350M)

        **Credits:**

        - [WavTokenizer](https://github.com/jishengpeng/WavTokenizer)
        - [CTC Forced Alignment](https://pytorch.org/audio/stable/tutorials/ctc_forced_alignment_api_tutorial.html)
        - [faster-whisper](https://github.com/guillaumekln/faster-whisper)
        """
    )
def _main():
    """Entry point when run as a script: start the Gradio server for `demo`."""
    demo.launch()


if __name__ == "__main__":
    _main()