Spaces:
Sleeping
Sleeping
import os | |
import torch | |
import argparse | |
import gradio as gr | |
import openai | |
from zipfile import ZipFile | |
import requests | |
import se_extractor | |
from api import BaseSpeakerTTS, ToneColorConverter | |
import langid | |
import traceback | |
from dotenv import load_dotenv | |
# Pull settings from a local .env file (if any) into the process environment
# before the rest of the script reads os.environ / os.getenv.
load_dotenv()
# Function to download and extract checkpoints | |
def download_and_extract_checkpoints():
    """Download and unpack the model checkpoints if they are not present yet.

    When a ``checkpoints`` path already exists in the current working
    directory the function returns without doing anything. Otherwise the
    archive is streamed to ``checkpoints.zip``, extracted into the current
    directory, and the archive is deleted again.

    Raises:
        requests.HTTPError: If the server answers with a 4xx/5xx status.
        requests.RequestException: On connection errors or timeouts.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
    zip_path = "checkpoints.zip"
    if not os.path.exists("checkpoints"):
        print("Downloading checkpoints...")
        try:
            # (connect, read) timeouts so a stalled server cannot hang start-up forever.
            with requests.get(zip_url, stream=True, timeout=(10, 60)) as response:
                # Fail loudly on an HTTP error instead of saving the error page as the archive.
                response.raise_for_status()
                with open(zip_path, "wb") as zip_file:
                    for chunk in response.iter_content(chunk_size=8192):
                        if chunk:
                            zip_file.write(chunk)
            print("Extracting checkpoints...")
            with ZipFile(zip_path, "r") as zip_ref:
                zip_ref.extractall(".")
        finally:
            # Never leave a partial or already-extracted archive behind.
            if os.path.exists(zip_path):
                os.remove(zip_path)
        print("Checkpoints are ready.")
# Make sure the checkpoints exist on disk before any model tries to load them.
download_and_extract_checkpoints()

# Initialize OpenAI API key; abort start-up early when it is missing.
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Please set the OPENAI_API_KEY environment variable.")

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument("--share", action='store_true', default=False, help="make link public")
args = parser.parse_args()

# Define paths to checkpoints
en_ckpt_base = 'checkpoints/base_speakers/EN'
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
ckpt_converter = 'checkpoints/converter'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)

# Load TTS models (one base speaker per language, plus the tone color converter)
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

# Load speaker embeddings.
# map_location keeps loading working on a CPU-only host even if a tensor was
# serialized from a CUDA device.
# NOTE(review): how these files were saved is not visible here.
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth', map_location=device).to(device)
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth', map_location=device).to(device)
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth', map_location=device).to(device)

# Extract speaker embedding from the default Mickey Mouse audio
default_speaker_audio = "resources/output.wav"
try:
    target_se, _ = se_extractor.get_se(
        default_speaker_audio,
        tone_color_converter,
        target_dir='processed',
        vad=True,
    )
    print("Speaker embedding extracted successfully.")
except Exception as e:
    # Chain the original exception so its traceback is not lost.
    raise RuntimeError(
        f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}"
    ) from e

# Supported languages (langid codes)
supported_languages = ['zh', 'en']
def predict(audio_file_pth, agree):
    """Answer a recorded question with a synthesized voice reply.

    Steps: transcribe the recording with OpenAI Whisper, detect the language
    of the transcript with langid, request a reply from the chat model,
    synthesize that reply with the base-speaker TTS model for the detected
    language, and finally run the tone color converter towards ``target_se``.

    Args:
        audio_file_pth: Path of the recorded audio file, or None when nothing
            was recorded.
        agree: Whether the Terms & Conditions checkbox is ticked.

    Returns:
        A ``(message, audio_path)`` tuple. On success ``audio_path`` is the
        path of the converted reply; on any failure it is None and
        ``message`` carries an ``[ERROR] ...`` description.
    """
    # Guard clauses: nothing is processed without consent and a recording.
    if not agree:
        return ('[ERROR] Please accept the Terms & Conditions!\n', None)
    if audio_file_pth is None:
        return ("[ERROR] Please record your voice using the Microphone.\n", None)

    # Speech -> text (OpenAI Whisper).
    try:
        with open(audio_file_pth, 'rb') as recording:
            transcript = openai.Audio.transcribe(
                model="whisper-1",
                file=recording,
                response_format='text',
            )
        input_text = transcript.strip()
        print(f"Transcribed Text: {input_text}")
    except Exception as exc:
        return (f"[ERROR] Transcription failed: {str(exc)}\n", None)

    if not input_text:
        return ("[ERROR] No speech detected in the audio.\n", None)

    # Language detection decides which base speaker is used.
    language_predicted = langid.classify(input_text)[0].strip()
    print(f"Detected language: {language_predicted}")
    if language_predicted not in supported_languages:
        return (
            f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n",
            None,
        )

    speaker_style = 'default'
    if language_predicted == "zh":
        tts_model, language = zh_base_speaker_tts, 'Chinese'
    else:
        tts_model, language = en_base_speaker_tts, 'English'

    # Text -> reply text (chat completion).
    try:
        completion = openai.ChatCompletion.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
                {"role": "user", "content": input_text},
            ],
            max_tokens=200,
            temperature=0.7,
        )
        reply_text = completion['choices'][0]['message']['content'].strip()
        print(f"GPT-4 Reply: {reply_text}")
    except Exception as exc:
        return (f"[ERROR] Failed to get response from OpenAI GPT-4: {str(exc)}\n", None)

    # Reply text -> base-speaker audio -> converted audio.
    try:
        src_path = os.path.join(output_dir, 'tmp_reply.wav')
        tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
        print(f"Audio synthesized and saved to {src_path}")

        save_path = os.path.join(output_dir, 'output_reply.wav')
        tone_color_converter.convert(
            audio_src_path=src_path,
            src_se=zh_source_se if language == 'Chinese' else en_source_default_se,
            tgt_se=target_se,
            output_path=save_path,
            message="@MickeyMouse",
        )
        print(f"Tone color conversion completed and saved to {save_path}")
    except Exception as exc:
        failure = f"[ERROR] Failed to synthesize audio: {str(exc)}\n"
        traceback.print_exc()
        return (failure, None)

    return ("Response generated successfully.\n", save_path)
# Two-column layout: recording controls on the left, results on the right.
with gr.Blocks(analytics_enabled=False) as demo:
    gr.Markdown("# Mickey Mouse Voice Assistant")

    with gr.Row():
        # Inputs: microphone recording, consent checkbox and the send button.
        with gr.Column():
            audio_input = gr.Audio(
                source="microphone",
                type="filepath",
                label="Record Your Voice",
                info="Click the microphone button to record your voice.",
            )
            tos_checkbox = gr.Checkbox(
                label="Agree to Terms & Conditions",
                value=False,
                info="I agree to the terms of service.",
            )
            submit_button = gr.Button("Send")

        # Outputs: status text and the synthesized reply.
        with gr.Column():
            info_output = gr.Textbox(label="Info", interactive=False, lines=4)
            audio_output = gr.Audio(
                label="Mickey's Response",
                interactive=False,
                autoplay=True,
            )

    # predict returns (message, audio_path), matching the two outputs in order.
    submit_button.click(
        predict,
        inputs=[audio_input, tos_checkbox],
        outputs=[info_output, audio_output],
    )
# Launch the Gradio app
demo.queue()
demo.launch(
    # Listen on all interfaces; the port comes from $PORT with 7860 as fallback.
    server_name="0.0.0.0",
    server_port=int(os.environ.get("PORT", 7860)),
    debug=True,
    show_api=True,
    # Honor the --share command-line flag parsed into `args` (defaults to False).
    share=args.share,
)