|
import asyncio |
|
import datetime |
|
import logging |
|
import os |
|
import time |
|
import traceback |
|
import tempfile |
|
from concurrent.futures import ThreadPoolExecutor |
|
from torch.nn.utils.parametrizations import weight_norm |
|
from scipy.io import wavfile |
|
import numpy as np |
|
import traceback |
|
import librosa |
|
import torch |
|
from fairseq import checkpoint_utils |
|
import uuid |
|
|
|
from config import Config |
|
from lib.infer_pack.models import ( |
|
SynthesizerTrnMs256NSFsid, |
|
SynthesizerTrnMs256NSFsid_nono, |
|
SynthesizerTrnMs768NSFsid, |
|
SynthesizerTrnMs768NSFsid_nono, |
|
) |
|
from rmvpe import RMVPE |
|
from vc_infer_pipeline import VC |
|
|
|
# Module-level cache dictionary. Not read or written by any function in this
# module as shown — TODO confirm intended use.
model_cache = dict()

# Named logger shared by the model-loading and audio-processing helpers.
logger = logging.getLogger("voice_processing")
|
|
|
def load_model(model_name):
    """
    Load an RVC synthesizer from ``weights/<model_name>`` for CPU inference.

    Args:
        model_name (str): Name of the model directory under ``weights``
            (e.g., 'mongolian7-male').

    Returns:
        torch.nn.Module | None: The synthesizer in eval mode, or None if
        loading fails (the error and traceback are logged).
    """
    try:
        logger.info(f"Loading model: {model_name}")

        model_dir = "weights"
        model_path = os.path.join(model_dir, model_name)

        # os.listdir returns entries in arbitrary order; sort so the same
        # checkpoint is picked on every run when several are present.
        pth_files = sorted(f for f in os.listdir(model_path) if f.endswith('.pth'))
        if not pth_files:
            logger.error(f"No .pth file found in {model_path}")
            return None

        pth_path = os.path.join(model_path, pth_files[0])
        logger.info(f"Found model file: {pth_path}")

        cpt = torch.load(pth_path, map_location="cpu", weights_only=True)
        logger.info("Model weights loaded successfully")

        tgt_sr = cpt["config"][-1]
        # Make the config slot at [-3] agree with the row count of the stored
        # speaker-embedding table (presumably the speaker count) so the
        # constructed model's emb_g matches the checkpoint.
        cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
        if_f0 = cpt.get("f0", 1)
        version = cpt.get("version", "v1")

        logger.info(f"Model config: sr={tgt_sr}, if_f0={if_f0}, version={version}")

        # Choose the architecture the checkpoint was trained with: v1/v2
        # feature width, with or without the pitch (f0) path. Same selection
        # as model_data().
        if version == "v1":
            if if_f0 == 1:
                model = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=False)
            else:
                model = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
        elif version == "v2":
            if if_f0 == 1:
                model = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=False)
            else:
                model = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
        else:
            logger.error(f"Unknown model version {version!r} for {model_name}")
            return None

        model.eval()
        # strict=False: checkpoint keys not used by the inference model are ignored.
        model.load_state_dict(cpt["weight"], strict=False)

        logger.info("Model initialized successfully")
        return model

    except Exception as e:
        logger.error(f"Error loading model {model_name}: {str(e)}")
        logger.error(traceback.format_exc())
        return None
|
|
|
def _pcm_to_float32(audio):
    """Convert samples from ``wavfile.read`` to float32, scaling integer PCM to [-1, 1].

    Signed integer PCM is divided by the magnitude of its dtype minimum
    (e.g. 32768 for int16). Unsigned PCM (8-bit WAV) is first centred on its
    midpoint (128). Floating-point data is only cast.
    """
    if np.issubdtype(audio.dtype, np.signedinteger):
        scale = np.float32(-np.iinfo(audio.dtype).min)
        return audio.astype(np.float32) / scale
    if np.issubdtype(audio.dtype, np.unsignedinteger):
        mid = np.float32((np.iinfo(audio.dtype).max + 1) / 2)
        return (audio.astype(np.float32) - mid) / mid
    return audio.astype(np.float32)


def process_audio(model, audio_file, logger, index_rate=0, use_uploaded_voice=True, uploaded_voice=None):
    """
    Run a WAV file through ``model.infer`` and return the synthesized audio.

    Args:
        model: Object exposing ``infer(phone=..., phone_lengths=..., pitch=...,
            nsff0=..., sid=...)``, e.g. the result of ``load_model``.
            ``None`` is rejected.
        audio_file: Path or file object accepted by ``scipy.io.wavfile.read``.
        logger: Logger used for progress and error messages.
        index_rate, use_uploaded_voice, uploaded_voice: Accepted for interface
            compatibility; not used by this function.

    Returns:
        ``(None, None, (sr, samples))``, where ``samples`` is a float32 numpy
        array with singleton dimensions squeezed and ``sr`` is the *input*
        file's sample rate. Returns ``None`` on any failure; the traceback is
        logged.

    NOTE(review): the raw waveform is passed as ``phone``. RVC synthesizers
    normally expect HuBERT content features there (compare ``hubert_model`` /
    ``VC`` used with ``model_data``). The model also presumably outputs at
    its own target rate rather than ``sr``. TODO: route this through the VC
    pipeline.
    """
    try:
        logger.info("Starting audio processing")

        if model is None:
            logger.error("No model provided for processing")
            return None

        sr, audio = wavfile.read(audio_file)
        logger.info(f"Loaded audio: sr={sr}Hz, shape={audio.shape}")

        # Integer PCM must be rescaled: a plain astype(float32) would leave
        # int16 samples in the +/-32768 range instead of [-1, 1].
        audio = _pcm_to_float32(audio)
        if audio.ndim > 1:
            # Down-mix (samples, channels) to mono.
            audio = audio.mean(axis=1, dtype=np.float32)

        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        input_tensor = torch.from_numpy(np.ascontiguousarray(audio)).to(device)
        if device.type == "cuda":
            model = model.cuda()

        num_samples = input_tensor.shape[0]
        with torch.no_grad():
            phone = input_tensor.unsqueeze(0)
            phone_lengths = torch.tensor([num_samples], dtype=torch.long, device=device)
            # In RVC's pitch-aware encoders the coarse pitch indexes an
            # embedding table, so it must be an integer tensor. nsff0 is the
            # continuous (float) pitch curve.
            pitch = torch.zeros(1, num_samples, dtype=torch.long, device=device)
            nsff0 = torch.zeros(1, num_samples, dtype=torch.float32, device=device)
            sid = torch.tensor([0], dtype=torch.long, device=device)

            output = model.infer(
                phone=phone,
                phone_lengths=phone_lengths,
                pitch=pitch,
                nsff0=nsff0,
                sid=sid,
            )

        # RVC synthesizers presumably return a tuple whose first item is the
        # waveform; accept a bare tensor too.
        if isinstance(output, (tuple, list)):
            output = output[0]
        output = output.detach().float().cpu().numpy().squeeze()

        logger.info(f"Processing complete, output shape: {output.shape}")
        return (None, None, (sr, output))

    except Exception as e:
        logger.error(f"Error processing audio: {str(e)}")
        logger.error(traceback.format_exc())
        return None
|
|
|
|
|
logging.getLogger("fairseq").setLevel(logging.WARNING) |
|
logging.getLogger("numba").setLevel(logging.WARNING) |
|
logging.getLogger("markdown_it").setLevel(logging.WARNING) |
|
logging.getLogger("urllib3").setLevel(logging.WARNING) |
|
logging.getLogger("matplotlib").setLevel(logging.WARNING) |
|
|
|
limitation = os.getenv("SYSTEM") == "spaces" |
|
|
|
config = Config() |
|
|
|
|
|
|
|
|
|
|
|
|
|
model_root = "weights" |
|
models = [d for d in os.listdir(model_root) if os.path.isdir(f"{model_root}/{d}")] |
|
models.sort() |
|
|
|
def get_unique_filename(extension):
    """Return a random file name of the form ``<uuid4>.<extension>``."""
    unique_id = uuid.uuid4()
    return "{}.{}".format(unique_id, extension)
|
|
|
def model_data(model_name):
    """
    Build everything needed to run voice conversion with one RVC model.

    Looks in ``<model_root>/<model_name>`` for a ``.pth`` checkpoint and an
    optional ``.index`` file. Instantiates the synthesizer matching the
    checkpoint's version and f0 flag on ``config.device`` (fp16 when
    ``config.is_half``, fp32 otherwise) and creates a ``VC`` for the
    checkpoint's target sample rate.

    Args:
        model_name (str): Sub-directory of ``model_root`` holding the model.

    Returns:
        tuple: ``(tgt_sr, net_g, vc, version, index_file, if_f0)``.
        ``index_file`` is ``""`` when the directory has no ``.index`` file.

    Raises:
        IndexError: If the directory contains no ``.pth`` file.
        ValueError: If the checkpoint version is neither ``"v1"`` nor ``"v2"``.
    """
    model_dir = f"{model_root}/{model_name}"
    # os.listdir returns entries in arbitrary order; list once and sort so the
    # same checkpoint / index file is chosen on every run.
    entries = sorted(os.listdir(model_dir))

    pth_path = [f"{model_dir}/{f}" for f in entries if f.endswith(".pth")][0]
    print(f"Loading {pth_path}")

    cpt = torch.load(pth_path, map_location="cpu", weights_only=True)
    tgt_sr = cpt["config"][-1]
    # Make config slot [-3] agree with the stored speaker-embedding table size.
    cpt["config"][-3] = cpt["weight"]["emb_g.weight"].shape[0]
    if_f0 = cpt.get("f0", 1)
    version = cpt.get("version", "v1")
    if version == "v1":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs256NSFsid(*cpt["config"], is_half=config.is_half)
        else:
            net_g = SynthesizerTrnMs256NSFsid_nono(*cpt["config"])
    elif version == "v2":
        if if_f0 == 1:
            net_g = SynthesizerTrnMs768NSFsid(*cpt["config"], is_half=config.is_half)
        else:
            net_g = SynthesizerTrnMs768NSFsid_nono(*cpt["config"])
    else:
        raise ValueError("Unknown version")
    # enc_q is dropped before loading weights; strict=False tolerates the
    # resulting unmatched checkpoint keys.
    del net_g.enc_q
    net_g.load_state_dict(cpt["weight"], strict=False)
    print("Model loaded")
    net_g.eval().to(config.device)
    if config.is_half:
        net_g = net_g.half()
    else:
        net_g = net_g.float()
    vc = VC(tgt_sr, config)

    # NOTE(review): RVC training commonly leaves both a "trained_*.index" and
    # an "added_*.index"; the RVC WebUI skips the "trained" ones. Prefer names
    # without "trained" (stable sort keeps alphabetical order within each
    # group), falling back to them only when nothing else exists.
    index_names = sorted(
        (f for f in entries if f.endswith(".index")),
        key=lambda name: "trained" in name,
    )
    if not index_names:
        print("No index file found")
        index_file = ""
    else:
        index_file = f"{model_dir}/{index_names[0]}"
        print(f"Index file found: {index_file}")

    return tgt_sr, net_g, vc, version, index_file, if_f0
|
|
|
def load_hubert():
    """
    Load the HuBERT model from ``hubert_base.pt`` via fairseq.

    The first model of the loaded ensemble is moved to ``config.device``,
    cast to fp16 when ``config.is_half`` is set (fp32 otherwise), and
    returned in eval mode.
    """
    ensemble, _cfg, _task = checkpoint_utils.load_model_ensemble_and_task(
        ["hubert_base.pt"],
        suffix="",
    )
    hubert = ensemble[0].to(config.device)
    if not config.is_half:
        return hubert.float().eval()
    return hubert.half().eval()
|
|
|
def get_model_names(root=None):
    """
    Return the sorted names of the model directories under ``root``.

    Sorting keeps the result consistent with the module-level ``models``
    list; ``os.listdir`` alone returns entries in arbitrary order.

    Args:
        root (str, optional): Directory to scan. Defaults to ``model_root``.

    Returns:
        list[str]: Names of immediate sub-directories; plain files are skipped.
    """
    root = model_root if root is None else root
    return sorted(d for d in os.listdir(root) if os.path.isdir(os.path.join(root, d)))
|
|
|
|
|
# Import-time side effects: the HuBERT model and the RMVPE model are loaded
# once, when this module is imported, and shared as module globals.
hubert_model = load_hubert()

# RMVPE is constructed from the "rmvpe.pt" checkpoint with the same
# half-precision flag and device as the rest of the config.
rmvpe_model = RMVPE("rmvpe.pt", config.is_half, config.device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
def run_async_in_thread(fn, *args):
    """
    Run ``fn(*args)`` to completion on a fresh event loop and return its result.

    The loop is installed as the current thread's event loop while it runs.
    This is meant for threads without a running loop, e.g. worker threads.

    The loop is always closed and uninstalled afterwards, even when the
    awaitable raises. Previously a failure leaked the open loop, and success
    left a closed loop installed as the thread's current loop.

    Args:
        fn: Callable returning an awaitable (typically an ``async def``).
        *args: Positional arguments forwarded to ``fn``.

    Returns:
        Whatever the awaitable produced. Exceptions it raises propagate.
    """
    loop = asyncio.new_event_loop()
    try:
        asyncio.set_event_loop(loop)
        return loop.run_until_complete(fn(*args))
    finally:
        asyncio.set_event_loop(None)
        loop.close()
|
|
|
def parallel_tts(tasks):
    """
    Process a batch of voice-conversion tasks.

    Despite the name, tasks run one after another in the calling thread.
    Each checkpoint is loaded at most once per call. A model that loads
    successfully is reused for later tasks naming the same model; a failed
    load is retried for the next task that asks for it.

    Args:
        tasks (Sequence[tuple]): 6-tuples
            ``(model_name, _, _, slang_rate, use_uploaded_voice, audio_file)``.
            The 2nd and 3rd fields are ignored, and ``slang_rate`` is only logged.

    Returns:
        list: One entry per task, in order. Each entry is the value returned by
        ``process_audio``, or ``None`` if that task failed.
    """
    logger.info(f"Received {len(tasks)} tasks for processing")
    results = []
    # Models loaded during this call, keyed by name, so repeated tasks for
    # the same voice don't re-read the checkpoint from disk.
    loaded_models = {}

    for i, task in enumerate(tasks):
        try:
            logger.info(f"Processing task {i+1}/{len(tasks)}")

            model_name, _, _, slang_rate, use_uploaded_voice, audio_file = task
            logger.info(f"Model: {model_name}, Slang rate: {slang_rate}")

            model = loaded_models.get(model_name)
            if model is None:
                model = load_model(model_name)
                if model is None:
                    logger.error(f"Failed to load model {model_name}")
                    results.append(None)
                    continue
                loaded_models[model_name] = model

            result = process_audio(
                model=model,
                audio_file=audio_file,
                logger=logger,
                index_rate=0,
                use_uploaded_voice=use_uploaded_voice,
                uploaded_voice=None,
            )
            logger.info(f"Processing completed for task {i+1}")

            results.append(result)

        except Exception as e:
            logger.error(f"Error processing task {i+1}: {str(e)}")
            logger.error(traceback.format_exc())
            results.append(None)

    return results
|
|
|
|