import os
import re
import tempfile
import uuid

import gradio as gr
import librosa
import numpy as np
import soundfile as sf
import torch
import torchaudio
import uvicorn
from audiocraft.models import MusicGen
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import JSONResponse
from huggingface_hub import hf_hub_download, snapshot_download
from langdetect import detect
from tqdm import tqdm
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm

device = "cuda" if torch.cuda.is_available() else "cpu"

# Music generation model. MusicGen outputs audio at musicgen_model.sample_rate (32 kHz for the small model).
musicgen_model = MusicGen.get_pretrained("facebook/musicgen-small", device=device)

# Download the viXTTS checkpoint (plus the XTTS-v2 speaker file) if it is not already present locally.
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
os.makedirs(checkpoint_dir, exist_ok=True)

required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoint_dir)
    hf_hub_download(repo_id="coqui/XTTS-v2", filename="speakers_xtts.pth", local_dir=checkpoint_dir)

# Load XTTS: initialise from config, load the checkpoint, then move to the target device.
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
xtts_model = Xtts.init_from_config(config)
xtts_model.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)
xtts_model.to(device)

supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")


def normalize_vietnamese_text(text):
    """Normalize Vietnamese text for TTS and clean up punctuation artifacts left by the normalizer."""
    return (
        TTSnorm(text, unknown=False, lower=False, rule=True)
        .replace("..", ".")
        .replace("!.", "!")
        .replace("?.", "?")
        .replace(" .", ".")
        .replace(" ,", ",")
        .replace('"', "")
        .replace("'", "")
        .replace("AI", "Ây Ai")
        .replace("A.I", "Ây Ai")
    )


def analyze_music_for_emotion(music_path):
    """Heuristically classify the mood of an audio file from tempo, energy, and spectral features."""
    y, sr = librosa.load(music_path)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)
    rmse = librosa.feature.rms(y=y)[0]
    spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)
    spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)
    if tempo > 120 and np.mean(rmse) > 0.2:
        return "energetic"
    elif tempo < 80 and np.mean(chroma_stft[0]) > 0.5:
        return "sad"
    elif np.mean(spec_cent) > 2000 and np.mean(spec_bw) > 1000:
        return "happy"
    else:
        return "calm"


def detect_audio_language(audio_path):
    """Attempt to detect the language of an audio clip.

    Note: langdetect operates on text, not raw audio bytes, so passing the WAV bytes
    usually raises and this function falls back to returning None; a speech-to-text
    pass would be needed for reliable detection.
    """
    try:
        y, sr = librosa.load(audio_path)
        segment = y[: int(30 * sr)]  # analyse only the first 30 seconds
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            sf.write(temp_file.name, segment, sr)
            temp_filepath = temp_file.name
        with open(temp_filepath, "rb") as f:
            language = detect(f.read())
        os.remove(temp_filepath)
        return language
    except Exception:
        return None


def split_text_into_chunks(text, max_chunk_length=200):
    """Split text into whitespace-separated chunks of at most max_chunk_length characters."""
    words = text.split()
    chunks = []
    current_chunk = ""
    for word in words:
        if len(current_chunk) + len(word) + 1 <= max_chunk_length:
            current_chunk += word + " "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = word + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks


async def save_audio_to_storage(audio_data, filename, sample_rate=24000):
    """Write a waveform tensor to a temporary WAV file and return its path.

    `filename` is only a descriptive label; the file itself is a temp file. XTTS outputs
    24 kHz audio (the default); MusicGen output should be saved at musicgen_model.sample_rate.
    """
    if audio_data.dim() == 1:
        audio_data = audio_data.unsqueeze(0)  # torchaudio.save expects [channels, samples]
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        torchaudio.save(temp_file.name, audio_data.cpu(), sample_rate)
        return temp_file.name


async def generate_music_with_voice(description, melody_audio, voice_audio, duration, text_prompt, language):
    try:
        # Strip URLs and non-alphanumeric characters from the music description.
        description = re.sub(r"http\S+", "", description)
        description = re.sub(r"[^a-zA-Z0-9\s]", "", description)
        # MusicGen expects the duration in seconds.
        duration = float(duration) if duration else 30.0
        musicgen_model.set_generation_params(duration=duration)

        with torch.no_grad():
            if description:
                if melody_audio:
                    # Condition the music on both the text description and the melody's chroma.
                    melody, sr = torchaudio.load(melody_audio, normalize=True)
                    melody = melody.to(device)
                    wav_music = musicgen_model.generate_with_chroma([description], melody[None], sr)
                    detected_language = detect_audio_language(melody_audio)
                    if detected_language:
                        language = detected_language
                else:
                    wav_music = musicgen_model.generate([description])
            else:
                wav_music = musicgen_model.generate_unconditional(1)

        music_filename = await save_audio_to_storage(
            wav_music[0].cpu(), "music_" + str(uuid.uuid4()) + ".wav", sample_rate=musicgen_model.sample_rate
        )

        if language not in supported_languages:
            raise ValueError(f"Language {language} not supported")
        if not text_prompt and not voice_audio:
            raise ValueError("Text prompt or voice audio is required")
        if text_prompt and len(text_prompt) > 1000:
            raise ValueError("Text prompt is too long, please keep it under 1000 characters")
        if text_prompt and language == "vi":
            text_prompt = normalize_vietnamese_text(text_prompt)

        # Use the provided voice sample as the speaker reference, falling back to the generated music.
        speaker_wav = voice_audio if voice_audio else music_filename
        gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(
            audio_path=speaker_wav, gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60
        )

        # Adjust delivery based on the mood of the reference audio.
        emotion = analyze_music_for_emotion(music_filename if not voice_audio else voice_audio)
        prosody_strength = 1.0
        speaking_rate = 1.0
        if emotion == "energetic":
            prosody_strength = 1.2
            speaking_rate = 1.1
        elif emotion == "sad":
            prosody_strength = 0.8
            speaking_rate = 0.9
        elif emotion == "happy":
            prosody_strength = 1.1
            speaking_rate = 1.05

        voice_filename = None
        if voice_audio:
            voice_filename = voice_audio
        elif text_prompt:
            text_chunks = split_text_into_chunks(text_prompt)
            wav_chunks = []
            for chunk in tqdm(text_chunks, desc="Synthesizing voice chunks"):
                # XTTS exposes speaking rate via `speed`; it has no prosody_strength argument,
                # so that value is computed above but not passed to inference.
                out = xtts_model.inference(
                    chunk,
                    language,
                    gpt_cond_latent,
                    speaker_embedding,
                    repetition_penalty=5.0,
                    temperature=0.75,
                    enable_text_splitting=True,
                    speed=speaking_rate,
                )
                wav_chunks.append(torch.tensor(out["wav"]))
            final_wav = torch.cat(wav_chunks, dim=-1)
            voice_filename = await save_audio_to_storage(final_wav, "voice_" + str(uuid.uuid4()) + ".wav")

        return music_filename, voice_filename
    except IsADirectoryError:
        return "Error: Provided path is a directory, not a file.", "Error: Provided path is a directory, not a file."
    except Exception as e:
        return str(e), str(e)


# Gradio UI components.
description = gr.Textbox(label="Description", placeholder="Acoustic, guitar, melody, trap, D minor, 90 bpm")
melody_audio = gr.Audio(label="Melody Audio (optional)", type="filepath")
voice_audio = gr.Audio(label="Voice Audio (optional)", type="filepath")
duration = gr.Number(label="Duration (seconds)", value=None)
text_prompt = gr.Textbox(label="Text Prompt (optional if voice audio is provided)", placeholder="Input text for TTS")
language = gr.Dropdown(
    label="Language (auto-detected from audio if provided)",
    choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi", "vi"],
    value="en",
)
output_music_path = gr.File(label="Generated Music")
output_voice_path = gr.File(label="Generated Voice")

iface = gr.Interface(
    fn=generate_music_with_voice,
    inputs=[description, melody_audio, voice_audio, duration, text_prompt, language],
    outputs=[output_music_path, output_voice_path],
    title="MusicGen with viXTTS",
    description="Generate music with the MusicGen model and synthesize a voice to match the rhythm using viXTTS.",
    examples=[],
    allow_flagging="never",
)

# FastAPI app exposing the same pipeline as an HTTP endpoint.
app = FastAPI()


@app.post("/synthesize")
async def api_synthesize(prompt: str, language: str = "en", audio_file: UploadFile = File(...)):
    try:
        # Persist the uploaded reference audio to a temporary file for the pipeline.
        temp_audio_file = tempfile.NamedTemporaryFile(delete=False)
        temp_audio_file.write(audio_file.file.read())
        temp_audio_file.close()
        music_output, voice_output = await generate_music_with_voice(
            prompt, None, temp_audio_file.name, None, None, language
        )
        return JSONResponse(content={"music_output": music_output, "voice_output": voice_output})
    except Exception as e:
        return JSONResponse(content={"error": str(e)}, status_code=500)


if __name__ == "__main__":
    # Launch the Gradio UI without blocking the main thread, then serve the FastAPI app.
    # FastAPI has no app.run(); it must be served by an ASGI server such as uvicorn.
    iface.launch(share=True, prevent_thread_lock=True)
    uvicorn.run(app, host="0.0.0.0", port=8000)
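# Example client call (a sketch, not part of the app): with the server running on
# localhost:8000 as configured above, the /synthesize endpoint can be exercised roughly
# like this. "sample.wav" is a hypothetical local reference recording; prompt and
# language are sent as query parameters and the audio as a multipart file upload.
#
#     import requests
#
#     with open("sample.wav", "rb") as f:
#         r = requests.post(
#             "http://localhost:8000/synthesize",
#             params={"prompt": "calm piano melody", "language": "en"},
#             files={"audio_file": f},
#         )
#     print(r.json())  # expected: {"music_output": <path to music wav>, "voice_output": <path to voice wav>}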