Spaces:
Running
Running
import gradio as gr | |
from outetts.v0_1.interface import InterfaceHF | |
import logging | |
import os | |
import tempfile | |
# Import faster-whisper for transcription | |
from faster_whisper import WhisperModel | |
# Configure logging to display information in the terminal
logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Initialize the OuteTTS interface with the Hugging Face model.
# Loaded once at import time; the request handlers below reuse this instance.
try:
    logger.info("Initializing OuteTTS InterfaceHF with model 'OuteAI/OuteTTS-0.1-350M'")
    interface = InterfaceHF("OuteAI/OuteTTS-0.1-350M")
    logger.info("Model loaded successfully.")
except Exception as e:
    # logger.exception records the traceback; a bare `raise` re-raises the
    # original exception without adding a re-raise frame (unlike `raise e`).
    logger.exception("Failed to load model: %s", e)
    raise

# Initialize the faster-whisper model used to transcribe reference audio for
# voice cloning ("tiny" checkpoint, CPU, int8 compute).
try:
    logger.info("Initializing faster-whisper model for transcription.")
    whisper_model = WhisperModel("tiny", device="cpu", compute_type="int8")
    logger.info("faster-whisper model loaded successfully.")
except Exception as e:
    logger.exception("Failed to load faster-whisper model: %s", e)
    raise
def generate_tts(text, temperature, repetition_penalty, max_length, speaker):
    """
    Generates speech from the input text using the OuteTTS model.

    Parameters:
        text (str): The input text for TTS. Empty or whitespace-only text is
            rejected without calling the model.
        temperature (float): Sampling temperature.
        repetition_penalty (float): Repetition penalty.
        max_length (int | float): Maximum length of the generated audio tokens.
            Gradio sliders may deliver a float, so the value is cast to int.
        speaker (dict | None): Speaker configuration for voice cloning, as
            returned by ``interface.create_speaker``. ``None`` or an empty
            value means no cloned speaker is passed to the model.

    Returns:
        str | None: Path to a newly created WAV file holding the generated
        audio, or ``None`` if the text was empty or generation failed (the
        error is logged with its traceback).
    """
    logger.info("Received TTS generation request.")
    logger.info(
        "Parameters - Text: %s, Temperature: %s, Repetition Penalty: %s, Max Length: %s, Speaker: %s",
        text, temperature, repetition_penalty, max_length, speaker is not None,
    )

    if not text or not text.strip():
        logger.warning("Empty text received; skipping TTS generation.")
        return None

    output_path = None
    try:
        # The outetts v0.1 interface spells this keyword 'max_lenght' (sic),
        # so the typo is intentional here.
        output = interface.generate(
            text=text,
            temperature=temperature,
            repetition_penalty=repetition_penalty,
            max_lenght=int(max_length),
            speaker=speaker or None,
        )
        logger.info("TTS generation complete.")

        # One unique file per request: a shared fixed name ("output.wav") lets
        # concurrent requests overwrite each other's audio before it is served.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name
        output.save(output_path)
        logger.info("Audio saved to %s", output_path)
        return output_path  # Gradio will handle the audio playback
    except Exception as e:
        logger.exception("Error during TTS generation: %s", e)
        # Don't leave an empty/partial output file behind on failure.
        if output_path is not None and os.path.exists(output_path):
            try:
                os.remove(output_path)
            except OSError:
                pass
        return None
def transcribe_audio(audio_path, model=None):
    """
    Transcribes the given audio file using faster-whisper.

    Parameters:
        audio_path (str): Path to the audio file.
        model: Optional transcription backend exposing
            ``transcribe(path) -> (segments, info)`` where each segment has a
            ``.text`` attribute. Defaults to the module-level ``whisper_model``.

    Returns:
        str: The segment texts, each stripped of surrounding whitespace and
        joined with single spaces (empty segments are skipped). Returns ``""``
        when nothing was recognised.
    """
    if model is None:
        model = whisper_model
    logger.info("Transcribing audio file: %s", audio_path)
    segments, _info = model.transcribe(audio_path)
    # Segment texts may carry leading/trailing spaces; strip each one so the
    # join below does not produce doubled or leading whitespace. `segments`
    # may be a lazy generator, so it is consumed exactly once here.
    transcript = " ".join(
        piece for piece in (segment.text.strip() for segment in segments) if piece
    )
    logger.info("Transcription complete: %s", transcript)
    return transcript
def create_speaker_with_transcription(audio_file):
    """
    Creates a custom speaker from a reference audio file by automatically transcribing it.

    Parameters:
        audio_file: The uploaded reference audio, either
            - a filesystem path (``str`` / ``os.PathLike``), e.g. what
              ``gr.Audio(type="filepath")`` delivers; it is used in place and
              never deleted, or
            - a binary file-like object with ``.read()``; its bytes are copied
              to a temporary ``.wav`` file that is always removed afterwards.

    Returns:
        dict | None: Speaker configuration from ``interface.create_speaker``,
        or ``None`` if no audio was given, the transcript was empty, or any
        step failed (the error is logged with its traceback).
    """
    logger.info("Received Voice Cloning request with audio file.")
    if audio_file is None or (isinstance(audio_file, str) and not audio_file):
        logger.error("No reference audio provided.")
        return None

    temp_audio_path = None
    try:
        if isinstance(audio_file, (str, os.PathLike)):
            audio_path = os.fspath(audio_file)
        else:
            # Write through the handle we already hold instead of reopening the
            # temp file by name (reopening an open NamedTemporaryFile fails on
            # Windows).
            with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_audio:
                temp_audio_path = temp_audio.name
                temp_audio.write(audio_file.read())
            audio_path = temp_audio_path
            logger.info("Reference audio saved to %s", temp_audio_path)

        # Transcribe the audio file
        transcript = transcribe_audio(audio_path)
        if not transcript.strip():
            logger.error("Transcription resulted in empty text.")
            return None

        # Create speaker using the transcribed text
        speaker = interface.create_speaker(audio_path, transcript)
        logger.info("Speaker created successfully.")
        return speaker
    except Exception as e:
        logger.exception("Error during speaker creation: %s", e)
        return None
    finally:
        # Always remove our own temporary copy; previously this only happened
        # on success, leaking a file on every failed or empty transcription.
        if temp_audio_path is not None:
            try:
                os.remove(temp_audio_path)
                logger.info("Temporary audio file %s removed.", temp_audio_path)
            except OSError as err:
                logger.warning(
                    "Could not remove temporary audio file %s: %s", temp_audio_path, err
                )
# Define the Gradio Blocks interface
with gr.Blocks() as demo:
    gr.Markdown("# π€ OuteTTS - Text to Speech Interface")
    gr.Markdown(
        """
        Generate speech from text using the **OuteTTS-0.1-350M** model.
        **Key Features:**
        - Pure language modeling approach to TTS
        - Voice cloning capabilities with automatic transcription
        - Compatible with LLaMa architecture
        """
    )

    with gr.Tab("Basic TTS"):
        with gr.Row():
            text_input = gr.Textbox(
                label="π Text Input",
                placeholder="Enter the text for TTS generation",
                lines=3,
            )
        with gr.Row():
            temperature = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.1,
                step=0.01,
                label="π‘οΈ Temperature",
            )
            repetition_penalty = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.1,
                step=0.1,
                label="π Repetition Penalty",
            )
            max_length = gr.Slider(
                minimum=256,
                maximum=4096,
                value=1024,
                step=256,
                label="π Max Length",
            )
        generate_button = gr.Button("π Generate Speech")
        output_audio = gr.Audio(
            label="π§ Generated Speech",
            type="filepath",  # Expecting a file path to the audio
        )
        # Event inputs must all be components; the original passed a bare
        # `None`, which is not a component and cannot be wired as an input.
        # A State holding None supplies the "no cloned speaker" argument.
        no_speaker = gr.State(value=None)
        # Define the button click event for Basic TTS
        generate_button.click(
            fn=generate_tts,
            inputs=[text_input, temperature, repetition_penalty, max_length, no_speaker],
            outputs=output_audio,
        )

    with gr.Tab("Voice Cloning"):
        with gr.Row():
            # `type="file"` and `optional=` are not accepted by current Gradio:
            # Audio supports "numpy"/"filepath", and upload-only is `sources=[...]`.
            reference_audio = gr.Audio(
                label="π Reference Audio",
                type="filepath",
                sources=["upload"],
            )
        create_speaker_button = gr.Button("π€ Create Speaker")
        speaker_info = gr.JSON(label="ποΈ Speaker Configuration", interactive=False)
        with gr.Row():
            generate_cloned_speech = gr.Textbox(
                label="π Text Input",
                placeholder="Enter the text for TTS generation with cloned voice",
                lines=3,
            )
        with gr.Row():
            temperature_clone = gr.Slider(
                minimum=0.1,
                maximum=1.0,
                value=0.1,
                step=0.01,
                label="π‘οΈ Temperature",
            )
            repetition_penalty_clone = gr.Slider(
                minimum=0.5,
                maximum=2.0,
                value=1.1,
                step=0.1,
                label="π Repetition Penalty",
            )
            max_length_clone = gr.Slider(
                minimum=256,
                maximum=4096,
                value=1024,
                step=256,
                label="π Max Length",
            )
        generate_cloned_button = gr.Button("π Generate Cloned Speech")
        output_cloned_audio = gr.Audio(
            label="π§ Generated Cloned Speech",
            type="filepath",  # Expecting a file path to the audio
        )

        def _create_speaker_from_upload(audio_path):
            """Open the uploaded file and pass its binary handle to the speaker builder."""
            if not audio_path:
                return None
            with open(audio_path, "rb") as uploaded:
                return create_speaker_with_transcription(uploaded)

        # Define the button click event for creating a speaker
        create_speaker_button.click(
            fn=_create_speaker_from_upload,
            inputs=[reference_audio],
            outputs=speaker_info,
        )
        # Define the button click event for generating speech with the cloned voice
        generate_cloned_button.click(
            fn=generate_tts,
            inputs=[
                generate_cloned_speech,
                temperature_clone,
                repetition_penalty_clone,
                max_length_clone,
                speaker_info,
            ],
            outputs=output_cloned_audio,
        )

    gr.Markdown(
        """
        ---
        **Technical Blog:** [OuteTTS-0.1-350M](https://www.outeai.com/blog/OuteTTS-0.1-350M)
        **Credits:**
        - [WavTokenizer](https://github.com/jishengpeng/WavTokenizer)
        - [CTC Forced Alignment](https://pytorch.org/audio/stable/tutorials/ctc_forced_alignment_api_tutorial.html)
        - [faster-whisper](https://github.com/guillaumekln/faster-whisper)
        """
    )

# Launch the Gradio app
if __name__ == "__main__":
    demo.launch()