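"""Voice chat demo: transcribe an uploaded voice clip with OpenAI Whisper, generate a
short in-character reply with an OpenAI chat model, synthesize it with the OpenVoice
base speaker TTS, and convert the tone color to a reference speaker before serving the
result through a FastAPI endpoint."""
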
import os
import torch
import argparse
import openai
from zipfile import ZipFile
import requests
import se_extractor
from api import BaseSpeakerTTS, ToneColorConverter
import langid
import traceback
from dotenv import load_dotenv
from fastapi import FastAPI, File, UploadFile, Form
from fastapi.responses import JSONResponse
from fastapi.staticfiles import StaticFiles
import uvicorn
# Load environment variables
load_dotenv()


def download_and_extract_checkpoints():
    zip_url = "https://huggingface.co/camenduru/OpenVoice/resolve/main/checkpoints_1226.zip"
    zip_path = "checkpoints.zip"
    if not os.path.exists("checkpoints"):
        print("Downloading checkpoints...")
        response = requests.get(zip_url, stream=True)
        with open(zip_path, "wb") as zip_file:
            for chunk in response.iter_content(chunk_size=8192):
                if chunk:
                    zip_file.write(chunk)
        print("Extracting checkpoints...")
        with ZipFile(zip_path, "r") as zip_ref:
            zip_ref.extractall(".")
        os.remove(zip_path)
    print("Checkpoints are ready.")


# Call the function to ensure checkpoints are available
download_and_extract_checkpoints()

# Initialize OpenAI API key
openai.api_key = os.getenv("OPENAI_API_KEY")
if not openai.api_key:
    raise ValueError("Please set the OPENAI_API_KEY environment variable.")

en_ckpt_base = 'checkpoints/base_speakers/EN'
zh_ckpt_base = 'checkpoints/base_speakers/ZH'
ckpt_converter = 'checkpoints/converter'
device = 'cuda' if torch.cuda.is_available() else 'cpu'
output_dir = 'outputs'
os.makedirs(output_dir, exist_ok=True)
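
# Load the English and Chinese base speaker TTS models, the tone color converter, and
# the pre-computed source speaker embeddings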
en_base_speaker_tts = BaseSpeakerTTS(f'{en_ckpt_base}/config.json', device=device)
en_base_speaker_tts.load_ckpt(f'{en_ckpt_base}/checkpoint.pth')
zh_base_speaker_tts = BaseSpeakerTTS(f'{zh_ckpt_base}/config.json', device=device)
zh_base_speaker_tts.load_ckpt(f'{zh_ckpt_base}/checkpoint.pth')

tone_color_converter = ToneColorConverter(f'{ckpt_converter}/config.json', device=device)
tone_color_converter.load_ckpt(f'{ckpt_converter}/checkpoint.pth')

en_source_default_se = torch.load(f'{en_ckpt_base}/en_default_se.pth').to(device)
en_source_style_se = torch.load(f'{en_ckpt_base}/en_style_se.pth').to(device)
zh_source_se = torch.load(f'{zh_ckpt_base}/zh_default_se.pth').to(device)
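
# Reference audio whose tone color is cloned onto every synthesized reply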
default_speaker_audio = "resources/output.wav"
try:
    target_se, _ = se_extractor.get_se(
        default_speaker_audio,
        tone_color_converter,
        target_dir='processed',
        vad=True
    )
    print("Speaker embedding extracted successfully.")
except Exception as e:
    raise RuntimeError(f"Failed to extract speaker embedding from {default_speaker_audio}: {str(e)}")

supported_languages = ['zh', 'en']
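

# Full pipeline: transcribe the uploaded voice, detect its language, generate a reply
# with the chat model, synthesize it with the matching base speaker TTS, and convert
# the tone color to match the reference speaker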
def predict(audio_file_pth, agree):
    text_hint = ''
    synthesized_audio_path = None

    if not agree:
        text_hint += '[ERROR] Please accept the Terms & Conditions!\n'
        return (text_hint, None)

    if audio_file_pth is not None:
        speaker_wav = audio_file_pth
    else:
        text_hint += "[ERROR] Please provide your voice as an audio file.\n"
        return (text_hint, None)

    # Transcribe audio to text using OpenAI Whisper
    try:
        with open(speaker_wav, 'rb') as audio_file:
            transcription_response = openai.audio.transcriptions.create(
                model="whisper-1",
                file=audio_file,
                response_format='text'
            )
        input_text = transcription_response.strip()
        print(f"Transcribed Text: {input_text}")
    except Exception as e:
        text_hint += f"[ERROR] Transcription failed: {str(e)}\n"
        return (text_hint, None)

    if len(input_text) == 0:
        text_hint += "[ERROR] No speech detected in the audio.\n"
        return (text_hint, None)

    language_predicted = langid.classify(input_text)[0].strip()
    print(f"Detected language: {language_predicted}")

    if language_predicted not in supported_languages:
        text_hint += f"[ERROR] The detected language '{language_predicted}' is not supported. Supported languages are: {supported_languages}\n"
        return (text_hint, None)
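
    # Pick the base speaker TTS that matches the detected language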
    if language_predicted == "zh":
        tts_model = zh_base_speaker_tts
        language = 'Chinese'
        speaker_style = 'default'
    else:
        tts_model = en_base_speaker_tts
        language = 'English'
        speaker_style = 'default'

    # Generate a short in-character reply with the OpenAI chat model
    try:
        response = openai.chat.completions.create(
            model="gpt-4o-mini",
            messages=[
                {"role": "system", "content": "You are Mickey Mouse, a friendly and cheerful character who responds to children's queries in a simple and engaging manner. Please keep your response up to 200 characters."},
                {"role": "user", "content": input_text}
            ],
            max_tokens=200,
            n=1,
            stop=None,
            temperature=0.7,
        )
        reply_text = response.choices[0].message.content.strip()
        print(f"Chat model reply: {reply_text}")
    except Exception as e:
        text_hint += f"[ERROR] Failed to get a response from the OpenAI chat model: {str(e)}\n"
        return (text_hint, None)
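
    # Synthesize the reply, then convert its tone color to match the reference speaker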
    try:
        src_path = os.path.join(output_dir, 'tmp_reply.wav')
        tts_model.tts(reply_text, src_path, speaker=speaker_style, language=language)
        print(f"Audio synthesized and saved to {src_path}")

        save_path = os.path.join(output_dir, 'output_reply.wav')
        tone_color_converter.convert(
            audio_src_path=src_path,
            src_se=en_source_default_se if language == 'English' else zh_source_se,
            tgt_se=target_se,
            output_path=save_path,
            message="@MickeyMouse"
        )
        print(f"Tone color conversion completed and saved to {save_path}")

        text_hint += "Response generated successfully.\n"
        synthesized_audio_path = save_path
    except Exception as e:
        text_hint += f"[ERROR] Failed to synthesize audio: {str(e)}\n"
        traceback.print_exc()
        return (text_hint, None)

    return (text_hint, synthesized_audio_path)


app = FastAPI()

# Mount the 'outputs' directory to serve static files
app.mount("/outputs", StaticFiles(directory="outputs"), name="outputs")
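
# Endpoint: accepts a consent flag and a voice recording, runs the predict pipeline,
# and returns the info string plus a URL under /outputs for the synthesized reply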
@app.post("/predict")  # expose the handler at /predict (path assumed)
async def predict_endpoint(agree: bool = Form(...), audio_file: UploadFile = File(...)):
    temp_dir = "temp"
    os.makedirs(temp_dir, exist_ok=True)
    audio_path = os.path.join(temp_dir, audio_file.filename)
    with open(audio_path, "wb") as f:
        f.write(await audio_file.read())

    info, audio_output_path = predict(audio_path, agree)

    if audio_output_path:
        audio_url = f"/outputs/{os.path.basename(audio_output_path)}"
        return {"info": info, "audio_path": audio_url}
    else:
        # Use JSONResponse so the 400 status code is actually sent
        return JSONResponse(status_code=400, content={"info": info, "audio_path": None})


if __name__ == "__main__":
    uvicorn.run(app, host="0.0.0.0", port=int(os.environ.get("PORT", 7860)))
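
# Example client call (a minimal sketch, assuming the server is running locally on the
# default port 7860 and that the endpoint is registered at /predict as above; the file
# name "my_voice.wav" is just a placeholder):
#
#   import requests
#
#   with open("my_voice.wav", "rb") as f:
#       r = requests.post(
#           "http://localhost:7860/predict",
#           data={"agree": "true"},
#           files={"audio_file": ("my_voice.wav", f, "audio/wav")},
#       )
#   print(r.json())  # e.g. {"info": "...", "audio_path": "/outputs/output_reply.wav"}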