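"""Gradio + FastAPI app: generate music with MusicGen and synthesize a matching voice with viXTTS (XTTS-v2 with Vietnamese support)."""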
import os
import uuid
import torch
import re
import gradio as gr
import torchaudio
from audiocraft.models import MusicGen
from audiocraft.data.audio import audio_write
from huggingface_hub import hf_hub_download, snapshot_download
from TTS.tts.configs.xtts_config import XttsConfig
from TTS.tts.models.xtts import Xtts
from vinorm import TTSnorm
from fastapi import FastAPI, File, UploadFile
from fastapi.responses import FileResponse, JSONResponse
import librosa
import soundfile as sf
from langdetect import detect
import numpy as np
import tempfile
from pathlib import Path
from tqdm import tqdm
import asyncio
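# Load MusicGen for music generation and download the viXTTS checkpoint
# (XTTS-v2 with Vietnamese support) into model/ if it is not already present.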
device = "cuda" if torch.cuda.is_available() else "cpu"
musicgen_model = MusicGen.get_pretrained("facebook/musicgen-small", device=device)
checkpoint_dir = "model/"
repo_id = "capleaf/viXTTS"
os.makedirs(checkpoint_dir, exist_ok=True)
required_files = ["model.pth", "config.json", "vocab.json", "speakers_xtts.pth"]
files_in_dir = os.listdir(checkpoint_dir)
if not all(file in files_in_dir for file in required_files):
    snapshot_download(repo_id=repo_id, repo_type="model", local_dir=checkpoint_dir)
    hf_hub_download(repo_id="coqui/XTTS-v2", filename="speakers_xtts.pth", local_dir=checkpoint_dir)
xtts_config = os.path.join(checkpoint_dir, "config.json")
config = XttsConfig()
config.load_json(xtts_config)
xtts_model = Xtts.init_from_config(config).to(device)
xtts_model.load_checkpoint(config, checkpoint_dir=checkpoint_dir, use_deepspeed=False)
supported_languages = config.languages
if "vi" not in supported_languages:
    supported_languages.append("vi")
def normalize_vietnamese_text(text):
    text = TTSnorm(text, unknown=False, lower=False, rule=True)
    text = text.replace("..", ".").replace("!.", "!").replace("?.", "?").replace(" .", ".").replace(" ,", ",")
    return text.replace('"', "").replace("'", "").replace("AI", "Ây Ai").replace("A.I", "Ây Ai")
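# Rough mood heuristic: classify the clip from tempo, RMS energy, chroma and
# spectral statistics; the thresholds below are hand-picked rather than learned.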
def analyze_music_for_emotion(music_path):
    y, sr = librosa.load(music_path)
    tempo, _ = librosa.beat.beat_track(y=y, sr=sr)
    chroma_stft = librosa.feature.chroma_stft(y=y, sr=sr)
    rmse = librosa.feature.rms(y=y)[0]
    spec_cent = librosa.feature.spectral_centroid(y=y, sr=sr)
    spec_bw = librosa.feature.spectral_bandwidth(y=y, sr=sr)
    if tempo > 120 and np.mean(rmse) > 0.2:
        return "energetic"
    elif tempo < 80 and np.mean(chroma_stft[0]) > 0.5:
        return "sad"
    elif np.mean(spec_cent) > 2000 and np.mean(spec_bw) > 1000:
        return "happy"
    else:
        return "calm"
def detect_audio_language(audio_path):
    # NOTE: langdetect operates on text, so passing raw WAV bytes raises and falls
    # through to the except branch; a transcription step would be needed for real
    # spoken-language detection.
    try:
        y, sr = librosa.load(audio_path)
        segment = y[:int(30 * sr)]
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
            sf.write(temp_file.name, segment, sr)
            temp_filepath = temp_file.name
        with open(temp_filepath, "rb") as f:
            language = detect(f.read())
        os.remove(temp_filepath)
        return language
    except Exception:
        return None
def split_text_into_chunks(text, max_chunk_length=200):
    words = text.split()
    chunks = []
    current_chunk = ""
    for word in words:
        if len(current_chunk) + len(word) + 1 <= max_chunk_length:
            current_chunk += word + " "
        else:
            chunks.append(current_chunk.strip())
            current_chunk = word + " "
    if current_chunk.strip():
        chunks.append(current_chunk.strip())
    return chunks
async def save_audio_to_storage(audio_data, filename, sample_rate=24000):
    # filename is only a label from the caller; the audio is written to a random temporary path.
    # viXTTS outputs 24 kHz audio; MusicGen output should be saved at musicgen_model.sample_rate (32 kHz).
    if audio_data.dim() == 1:
        audio_data = audio_data.unsqueeze(0)  # torchaudio.save expects a (channels, samples) tensor
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as temp_file:
        torchaudio.save(temp_file.name, audio_data, sample_rate)
    return temp_file.name
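# End-to-end pipeline: generate music (optionally melody-conditioned), then synthesize
# speech with viXTTS using either the uploaded voice sample or the generated music as
# the speaker reference, adjusting the speaking speed to the mood detected in that audio.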
async def generate_music_with_voice(description, melody_audio, voice_audio, duration, text_prompt, language):
    try:
        description = re.sub(r'http\S+', '', description)
        description = re.sub(r'[^a-zA-Z0-9\s]', '', description)
        # MusicGen expects the generation duration in seconds.
        duration = float(duration) if duration is not None else 30
        musicgen_model.set_generation_params(duration=duration)
        with torch.no_grad():
            if description:
                if melody_audio:
                    melody, sr = torchaudio.load(melody_audio, normalize=True)
                    melody = melody.to(device)
                    wav_music = musicgen_model.generate_with_chroma([description], melody[None], sr)
                    detected_language = detect_audio_language(melody_audio)
                    if detected_language:
                        language = detected_language
                else:
                    wav_music = musicgen_model.generate([description])
            else:
                wav_music = musicgen_model.generate_unconditional(1)
        music_filename = await save_audio_to_storage(wav_music[0].cpu(), "music_" + str(uuid.uuid4()) + ".wav", sample_rate=musicgen_model.sample_rate)
        if language not in supported_languages:
            raise ValueError(f"Language {language} is not supported")
        if not text_prompt and not voice_audio:
            raise ValueError("A text prompt or a voice audio file is required")
        if text_prompt and len(text_prompt) > 1000:
            raise ValueError("Text prompt is too long; please keep it under 1000 characters")
        if text_prompt and language == "vi":
            text_prompt = normalize_vietnamese_text(text_prompt)
        speaker_wav = voice_audio if voice_audio else music_filename
        gpt_cond_latent, speaker_embedding = xtts_model.get_conditioning_latents(audio_path=speaker_wav, gpt_cond_len=30, gpt_cond_chunk_len=4, max_ref_length=60)
        emotion = analyze_music_for_emotion(voice_audio if voice_audio else music_filename)
        prosody_strength = 1.0  # not consumed below: Xtts.inference has no prosody parameter
        speaking_rate = 1.0
        if emotion == "energetic":
            prosody_strength = 1.2
            speaking_rate = 1.1
        elif emotion == "sad":
            prosody_strength = 0.8
            speaking_rate = 0.9
        elif emotion == "happy":
            prosody_strength = 1.1
            speaking_rate = 1.05
        voice_filename = None
        if voice_audio:
            voice_filename = voice_audio
        elif text_prompt:
            text_chunks = split_text_into_chunks(text_prompt)
            wav_chunks = []
            for chunk in tqdm(text_chunks, desc="Synthesizing voice chunks"):
                out = xtts_model.inference(
                    chunk,
                    language,
                    gpt_cond_latent,
                    speaker_embedding,
                    repetition_penalty=5.0,
                    temperature=0.75,
                    enable_text_splitting=True,
                    # Xtts.inference does not accept prosody_strength/speaking_rate kwargs;
                    # its speed argument is the closest knob for the emotion-based rate.
                    speed=speaking_rate,
                )
                wav_chunks.append(torch.tensor(out["wav"]))
            final_wav = torch.cat(wav_chunks, dim=-1)
            voice_filename = await save_audio_to_storage(final_wav, "voice_" + str(uuid.uuid4()) + ".wav")
        return music_filename, voice_filename
    except IsADirectoryError:
        return "Error: Provided path is a directory, not a file.", "Error: Provided path is a directory, not a file."
    except Exception as e:
        return str(e), str(e)
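# Gradio front end; the same generate_music_with_voice coroutine also backs the REST endpoint below.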
description = gr.Textbox(label="Description", placeholder="Acoustic, guitar, melody, trap, D minor, 90 bpm")
melody_audio = gr.Audio(label="Melody Audio (optional)", type="filepath")
voice_audio = gr.Audio(label="Voice Audio (optional)", type="filepath")
duration = gr.Number(label="Duration (seconds)", value=None)
text_prompt = gr.Textbox(label="Text Prompt (optional if voice audio is provided)", placeholder="Input text for TTS")
language = gr.Dropdown(label="Language (auto-detected from audio if provided)", choices=["en", "es", "fr", "de", "it", "pt", "pl", "tr", "ru", "nl", "cs", "ar", "zh-cn", "ja", "ko", "hu", "hi", "vi"], max_choices=1, value="en")
output_music_path = gr.File(label="Generated Music")
output_voice_path = gr.File(label="Generated Voice")
iface = gr.Interface(
    fn=generate_music_with_voice,
    inputs=[description, melody_audio, voice_audio, duration, text_prompt, language],
    outputs=[output_music_path, output_voice_path],
    title="MusicGen with viXTTS",
    description="Generate music with the MusicGen model and synthesize a voice to match the rhythm using viXTTS.",
    examples=[],
    allow_flagging="never",
)
# prevent_thread_lock lets the script continue so the FastAPI wrapper below can also be served.
iface.launch(share=True, prevent_thread_lock=True)
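# Minimal REST wrapper around the same pipeline.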
app = FastAPI()
@app.post("/synthesize")
async def api_synthesize(prompt: str, language: str = "en", audio_file: UploadFile = File(...)):
    try:
        temp_audio_file = tempfile.NamedTemporaryFile(delete=False)
        temp_audio_file.write(audio_file.file.read())
        temp_audio_file.close()
        music_output, voice_output = await generate_music_with_voice(prompt, None, temp_audio_file.name, None, None, language)
        return JSONResponse(content={"music_output": music_output, "voice_output": voice_output})
    except Exception as e:
        return JSONResponse(content={"error": str(e)})
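# Example request (assuming the API is served locally on port 8000):
#   curl -X POST "http://localhost:8000/synthesize?prompt=calm%20piano&language=en" \
#        -F "audio_file=@reference_voice.wav"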
if __name__ == "__main__":
    import uvicorn  # FastAPI apps are served by an ASGI server; FastAPI has no app.run() method
    uvicorn.run(app, host="0.0.0.0", port=8000)