# nanee-convo / app.py — Hugging Face Space by ygauravyy
# Last commit: fff6648 "Update app.py" (verified) · 7.79 kB
import os
import torch
import argparse
import gradio as gr
import openai
from zipfile import ZipFile
import requests
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
import langid
import traceback
from dotenv import load_dotenv
# Copy variables from a local .env file (when one exists) into os.environ, so that
# OPENAI_API_KEY can be provided without exporting it in the shell.
# override=False is python-dotenv's default: values already present in the real
# environment take precedence over the .env file.
load_dotenv(override=False)
# Function to download and extract checkpoints
def download_and_extract_checkpoints(
    zip_url="https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip",
    zip_path="checkpoints.zip",
    checkpoints_dir="checkpoints",
    extract_to=".",
    timeout=60,
):
    """Download and unpack the OpenVoice checkpoint archive if it is not present.

    Nothing happens when ``checkpoints_dir`` already exists. Otherwise the archive
    at ``zip_url`` is streamed to ``zip_path`` and extracted into ``extract_to``.
    The temporary archive is always deleted afterwards, even when the download or
    extraction fails, so a corrupt file is never left behind.

    Args:
        zip_url: URL of the checkpoint archive.
        zip_path: Local file used to hold the archive while it is extracted.
        checkpoints_dir: Directory whose existence means the checkpoints are ready.
            Assumes the archive unpacks into this directory (true for the default
            archive as used by this app) — confirm if you change the defaults.
        extract_to: Directory the archive is extracted into.
        timeout: Seconds to wait on the server (connect / between reads).

    Raises:
        requests.HTTPError: If the server answers with an error status.
        zipfile.BadZipFile: If the downloaded file is not a valid zip archive.
    """
    if os.path.exists(checkpoints_dir):
        return

    print("Downloading checkpoints...")
    try:
        with requests.get(zip_url, stream=True, timeout=timeout) as response:
            # Fail loudly on 4xx/5xx instead of saving an error page as the "zip".
            response.raise_for_status()
            with open(zip_path, "wb") as zip_file:
                for chunk in response.iter_content(chunk_size=8192):
                    if chunk:  # skip empty keep-alive chunks
                        zip_file.write(chunk)
        print("Extracting checkpoints...")
        with ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(extract_to)
    finally:
        # Remove the archive on success and on failure alike.
        if os.path.exists(zip_path):
            os.remove(zip_path)
    print("Checkpoints are ready.")
# Make sure the model checkpoints are on disk before any model is constructed.
download_and_extract_checkpoints()

# Configure the OpenAI API key used by the openai calls in predict(); refuse to
# start without one.
openai.api_key = os.environ.get("OPENAI_API_KEY")
if openai.api_key is None or openai.api_key == "":
    raise ValueError("Please set the OPENAI_API_KEY environment variable.")

# Command-line options.
parser = argparse.ArgumentParser()
parser.add_argument(
    "--share",
    action="store_true",
    default=False,
    help="make link public",
)
args = parser.parse_args()
# Define paths to checkpoints
en_ckpt_base = 'checkpoints/base_speakers/EN'
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
ckpt_converter = 'checkpoints/converter'
# Use the GPU when one is available, otherwise fall back to the CPU.
device = 'cuda' if torch.cuda.is_available() else 'cpu'
# Directory that predict() writes synthesized replies into.
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)

# Load TTS models: one base speaker per supported language, plus the tone color
# converter that maps the base voice onto the target voice.
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')
tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

# Load source speaker embeddings for the base speakers.
# map_location remaps tensor storage onto `device` while deserializing, so
# embeddings that were saved from a GPU still load on a CPU-only host (without it,
# torch.load fails before the .to(device) call is ever reached).
en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth', map_location=device).to(device)
# Loaded for completeness; not referenced by predict().
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth', map_location=device).to(device)
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth', map_location=device).to(device)
# Extract speaker embedding from the default Mickey Mouse audio
default_speaker_audio = "resources/output.wav"
try:
    # target_dir='processed' is presumably where se_extractor writes its
    # intermediate files, and vad=True presumably enables voice-activity
    # detection — confirm against se_extractor.get_se.
    target_se, _ = se_extractor.get_se(
        default_speaker_audio,
        tone_color_converter,
        target_dir='processed',
        vad=True,
    )
    print("Speaker embedding extracted successfully.")
except Exception as e:
    # Every reply is converted to this target voice, so without the embedding the
    # app cannot work: abort startup and keep the original error as the cause.
    raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}") from e

# Supported languages: langid codes that predict() accepts; each one has a base
# TTS model loaded above.
supported_languages = ['zh', 'en']
def _transcribe(audio_path):
    """Transcribe the audio file at ``audio_path`` with OpenAI Whisper.

    Returns the stripped transcript. Errors from opening the file or from the API
    call propagate to the caller.

    NOTE(review): ``openai.Audio.transcribe`` belongs to the pre-1.0 openai SDK;
    confirm the pinned ``openai`` version still provides it.
    """
    with open(audio_path, 'rb') as audio_file:
        transcription_response = openai.Audio.transcribe(
            model="whisper-1",
            file=audio_file,
            response_format='text',
        )
    # With response_format='text' the transcript is expected as a plain string.
    input_text = transcription_response.strip()
    print(f"Transcribed Text: {input_text}")
    return input_text


def _generate_reply(input_text):
    """Return the chat model's stripped, in-character reply to ``input_text``.

    NOTE(review): ``openai.ChatCompletion.create`` belongs to the pre-1.0 openai
    SDK; confirm the pinned ``openai`` version still provides it.
    """
    response = openai.ChatCompletion.create(
        model="gpt-4o-mini",
        messages=[
            {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
            {"role": "user", "content": input_text},
        ],
        # Hard cap in tokens; the prompt separately asks for <= 200 characters.
        max_tokens=200,
        temperature=0.7,
    )
    reply_text = response['choices'][0]['message']['content'].strip()
    print(f"GPT-4 Reply: {reply_text}")
    return reply_text


def _synthesize_reply(reply_text, tts_model, language, source_se, speaker='default'):
    """Speak ``reply_text`` with ``tts_model`` and convert it to the target voice.

    The base-speaker audio goes to ``<output_dir>/tmp_reply.wav`` and the converted
    result to ``<output_dir>/output_reply.wav``. Both names are fixed, so each call
    overwrites the previous reply.

    Returns:
        Path of the tone-color-converted audio file.
    """
    src_path = os.path.join(output_dir, 'tmp_reply.wav')
    tts_model.tts(reply_text, src_path, speaker=speaker, language=language)
    print(f"Audio synthesized and saved to {src_path}")
    save_path = os.path.join(output_dir, 'output_reply.wav')
    tone_color_converter.convert(
        audio_src_path=src_path,
        src_se=source_se,
        tgt_se=target_se,
        output_path=save_path,
        message="@MickeyMouse",
    )
    print(f"Tone color conversion completed and saved to {save_path}")
    return save_path


def predict(audio_file_pth, agree):
    """Answer a recorded question with a spoken reply in the target voice.

    Pipeline: Whisper transcription -> language check (langid) -> chat completion
    -> base-speaker TTS -> tone color conversion to ``target_se``.

    Args:
        audio_file_pth: Path of the recorded audio (from the
            ``gr.Audio(type="filepath")`` input), or None when nothing was recorded.
        agree: Whether the Terms & Conditions checkbox is ticked.

    Returns:
        ``(text_hint, audio_path)``. ``text_hint`` is a status or error line for the
        Info box. ``audio_path`` is the converted reply's path, or None when any
        step failed.
    """
    # Agree with the terms
    if not agree:
        return ('[ERROR] Please accept the Terms & Conditions!\n', None)
    # Check if audio file is provided
    if audio_file_pth is None:
        return ("[ERROR] Please record your voice using the Microphone.\n", None)

    # Transcribe audio to text using OpenAI Whisper
    try:
        input_text = _transcribe(audio_file_pth)
    except Exception as e:
        traceback.print_exc()
        return (f"[ERROR] Transcription failed: {str(e)}\n", None)
    if len(input_text) == 0:
        return ("[ERROR] No speech detected in the audio.\n", None)

    # Detect language
    language_predicted = langid.classify(input_text)[0].strip()
    print(f"Detected language: {language_predicted}")
    if language_predicted not in supported_languages:
        return (
            f"[ERROR] The detected language '{language_predicted}' is not supported. "
            f"Supported languages are: {supported_languages}\n",
            None,
        )

    # Pick the base TTS model together with its matching source speaker embedding.
    if language_predicted == "zh":
        tts_model, language, source_se = zh_base_speaker_tts, 'Chinese', zh_source_se
    else:
        tts_model, language, source_se = en_base_speaker_tts, 'English', en_source_default_se

    # Generate the in-character reply text
    try:
        reply_text = _generate_reply(input_text)
    except Exception as e:
        traceback.print_exc()
        return (f"[ERROR] Failed to get response from OpenAI GPT-4: {str(e)}\n", None)

    # Synthesize the reply and convert it to the target voice
    try:
        save_path = _synthesize_reply(reply_text, tts_model, language, source_se)
    except Exception as e:
        traceback.print_exc()
        return (f"[ERROR] Failed to synthesize audio: {str(e)}\n", None)
    return ("Response generated successfully.\n", save_path)
# Gradio UI: record a question, accept the terms, and hear Mickey's spoken answer.
with gr.Blocks(analytics_enabled=False) as demo:
    gr.Markdown("# Mickey Mouse Voice Assistant")
    with gr.Row():
        with gr.Column():
            # NOTE(review): `source=` is the Gradio 3.x keyword (Gradio 4 uses
            # `sources=[...]`); confirm against the pinned gradio version.
            audio_input = gr.Audio(
                source="microphone",
                type="filepath",
                label="Record Your Voice",
                info="Click the microphone button to record your voice."
            )
            tos_checkbox = gr.Checkbox(
                label="Agree to Terms & Conditions",
                value=False,
                info="I agree to the terms of service."
            )
            submit_button = gr.Button("Send")
        with gr.Column():
            info_output = gr.Textbox(
                label="Info",
                interactive=False,
                lines=4,
            )
            audio_output = gr.Audio(
                label="Mickey's Response",
                interactive=False,
                autoplay=True,
            )
    # predict returns (text_hint, audio_path), matching the two outputs in order.
    submit_button.click(
        predict,
        inputs=[audio_input, tos_checkbox],
        outputs=[info_output, audio_output]
    )

# Launch the Gradio app
demo.queue()
demo.launch(
    server_name="0.0.0.0",
    server_port=int(os.environ.get("PORT", 7860)),
    debug=True,
    show_api=True,
    # Honour the --share command-line flag (defaults to False); it was previously
    # parsed but ignored because share was hard-coded to False.
    share=args.share,
)