|
import os |
|
import warnings |
|
from pathlib import Path |
|
from typing import Optional, Union |
|
|
|
import gradio as gr |
|
import numpy as np |
|
|
|
import torch |
|
from gradio.processing_utils import convert_to_16_bit_wav |
|
|
|
import utils |
|
from infer import get_net_g, infer |
|
from models import SynthesizerTrn |
|
from models_jp_extra import SynthesizerTrn as SynthesizerTrnJPExtra |
|
|
|
from .constants import ( |
|
DEFAULT_ASSIST_TEXT_WEIGHT, |
|
DEFAULT_LENGTH, |
|
DEFAULT_LINE_SPLIT, |
|
DEFAULT_NOISE, |
|
DEFAULT_NOISEW, |
|
DEFAULT_SDP_RATIO, |
|
DEFAULT_SPLIT_INTERVAL, |
|
DEFAULT_STYLE, |
|
DEFAULT_STYLE_WEIGHT, |
|
) |
|
from .log import logger |
|
|
|
|
|
def adjust_voice(fs, wave, pitch_scale, intonation_scale):
    """Rescale the pitch and intonation of a waveform using WORLD (pyworld).

    The f0 contour is extracted, every voiced frame is remapped to
    ``pitch_scale * mean + intonation_scale * (f0 - mean)`` where ``mean`` is
    the average f0 over voiced frames, and the signal is resynthesized.

    Args:
        fs: Sampling rate of ``wave``, passed through to pyworld.
        wave: NumPy array holding the waveform; it is cast to ``np.double``
            before analysis.
        pitch_scale: Multiplier applied to the mean f0.
        intonation_scale: Multiplier applied to each frame's deviation from
            the mean f0.

    Returns:
        A ``(fs, wave)`` tuple. When both scales are exactly 1.0 the input
        array is returned untouched. When no voiced frame is found the
        (double-cast) input is returned without resynthesis.

    Raises:
        ImportError: If a change is requested but pyworld is not installed.
    """
    if pitch_scale == 1.0 and intonation_scale == 1.0:
        # Nothing to change, so the optional pyworld dependency is not needed.
        return fs, wave

    try:
        import pyworld
    except ImportError as err:
        raise ImportError(
            "pyworld is not installed. Please install it by `pip install pyworld`"
        ) from err

    wave = wave.astype(np.double)
    f0, t = pyworld.harvest(wave, fs)

    # Frames with f0 == 0 are skipped (unvoiced in WORLD's convention).
    voiced = f0 != 0
    if not voiced.any():
        # No pitch to modify (e.g. silence); the original code divided by
        # zero here while computing the mean.
        return fs, wave

    sp = pyworld.cheaptrick(wave, f0, t, fs)
    ap = pyworld.d4c(wave, f0, t, fs)

    f0_mean = f0[voiced].mean()
    f0[voiced] = pitch_scale * f0_mean + intonation_scale * (f0[voiced] - f0_mean)

    wave = pyworld.synthesize(f0, sp, ap, fs)
    return fs, wave
|
|
|
|
|
class Model:
    """One TTS model: its hyper-parameters, style vectors and generator network.

    The hyper-parameters and style vectors are read in ``__init__``; the
    generator (``net_g``) is only loaded on demand via ``load_net_g``.
    """

    def __init__(
        self, model_path: Path, config_path: Path, style_vec_path: Path, device: str
    ):
        """Read the config and style vectors and validate that they agree.

        Args:
            model_path: Path to the generator weights file.
            config_path: Path to the config file parsed by
                ``utils.get_hparams_from_file``.
            style_vec_path: Path to a file loadable with ``np.load`` whose
                first axis indexes styles.
            device: Device string forwarded to ``get_net_g`` and ``infer``.

        Raises:
            ValueError: If the number of style names or of style vectors does
                not match ``hps.data.num_styles``.
        """
        self.model_path: Path = model_path
        self.config_path: Path = config_path
        self.style_vec_path: Path = style_vec_path
        self.device: str = device
        self.hps: utils.HParams = utils.get_hparams_from_file(self.config_path)
        self.spk2id: dict[str, int] = self.hps.data.spk2id
        self.id2spk: dict[int, str] = {v: k for k, v in self.spk2id.items()}

        self.num_styles: int = self.hps.data.num_styles
        if hasattr(self.hps.data, "style2id"):
            style2id = self.hps.data.style2id
        else:
            # Configs without named styles get "0", "1", ... as style names.
            style2id = {str(i): i for i in range(self.num_styles)}
        self.style2id: dict[str, int] = style2id
        if len(self.style2id) != self.num_styles:
            raise ValueError(
                f"Number of styles ({self.num_styles}) does not match the number of style2id ({len(self.style2id)})"
            )

        self.style_vectors: np.ndarray = np.load(self.style_vec_path)
        if self.style_vectors.shape[0] != self.num_styles:
            raise ValueError(
                f"The number of styles ({self.num_styles}) does not match the number of style vectors ({self.style_vectors.shape[0]})"
            )

        # Loaded lazily by load_net_g() on the first call to infer().
        self.net_g: Union[SynthesizerTrn, SynthesizerTrnJPExtra, None] = None

    def load_net_g(self):
        """Load the generator network with ``get_net_g`` and store it in ``net_g``."""
        self.net_g = get_net_g(
            model_path=str(self.model_path),
            version=self.hps.version,
            device=self.device,
            hps=self.hps,
        )

    def get_style_vector(self, style_id: int, weight: float = 1.0) -> np.ndarray:
        """Return the style vector for ``style_id`` scaled around the mean.

        The vector at index 0 is used as the mean; the result is
        ``mean + (style_vectors[style_id] - mean) * weight``.
        """
        mean = self.style_vectors[0]
        style_vec = self.style_vectors[style_id]
        return mean + (style_vec - mean) * weight

    def get_style_vector_from_audio(
        self, audio_path: str, weight: float = 1.0
    ) -> np.ndarray:
        """Extract a style vector from an audio file and scale it around the mean.

        The extraction is delegated to ``style_gen.get_style_vector``; the
        result is blended with the vector at index 0 exactly as in
        ``get_style_vector``.
        """
        # Imported here rather than at module level; presumably to avoid
        # loading the style extractor unless it is needed (TODO confirm).
        from style_gen import get_style_vector

        xvec = get_style_vector(audio_path)
        mean = self.style_vectors[0]
        return mean + (xvec - mean) * weight

    def infer(
        self,
        text: str,
        language: str = "JP",
        sid: int = 0,
        reference_audio_path: Optional[str] = None,
        sdp_ratio: float = DEFAULT_SDP_RATIO,
        noise: float = DEFAULT_NOISE,
        noisew: float = DEFAULT_NOISEW,
        length: float = DEFAULT_LENGTH,
        line_split: bool = DEFAULT_LINE_SPLIT,
        split_interval: float = DEFAULT_SPLIT_INTERVAL,
        assist_text: Optional[str] = None,
        assist_text_weight: float = DEFAULT_ASSIST_TEXT_WEIGHT,
        use_assist_text: bool = False,
        style: str = DEFAULT_STYLE,
        style_weight: float = DEFAULT_STYLE_WEIGHT,
        given_tone: Optional[list[int]] = None,
        pitch_scale: float = 1.0,
        intonation_scale: float = 1.0,
        ignore_unknown: bool = False,
    ) -> tuple[int, np.ndarray]:
        """Synthesize speech for ``text`` and return it as 16-bit audio.

        Args:
            text: Text to synthesize.
            language: Language code forwarded to the module-level ``infer``.
            sid: Speaker id forwarded to the module-level ``infer``.
            reference_audio_path: If given (and not ""), the style vector is
                extracted from this audio file instead of looked up by
                ``style``.
            sdp_ratio: Forwarded as ``sdp_ratio``.
            noise: Forwarded as ``noise_scale``.
            noisew: Forwarded as ``noise_scale_w``.
            length: Forwarded as ``length_scale``.
            line_split: If true, ``text`` is split on newlines, empty lines
                are dropped, each line is synthesized separately and the
                pieces are joined with silence.
            split_interval: Length in seconds of the silence inserted between
                lines when ``line_split`` is true.
            assist_text: Forwarded as ``assist_text``; ignored when it is ""
                or ``use_assist_text`` is false.
            assist_text_weight: Forwarded as ``assist_text_weight``.
            use_assist_text: Whether ``assist_text`` is used at all.
            style: Key into ``style2id`` selecting the style vector (only
                used when no reference audio is given).
            style_weight: Weight passed to the style-vector blending.
            given_tone: Forwarded as ``given_tone``; only used when
                ``line_split`` is false.
            pitch_scale: Pitch multiplier passed to ``adjust_voice``.
            intonation_scale: Intonation multiplier passed to
                ``adjust_voice``.
            ignore_unknown: Forwarded as ``ignore_unknown``.

        Returns:
            ``(sampling_rate, audio)`` where ``audio`` is the output of
            gradio's ``convert_to_16_bit_wav``.

        Raises:
            ValueError: If the model version ends with "JP-Extra" and
                ``language`` is not "JP", or if ``line_split`` is true and
                every line of ``text`` is empty.
            KeyError: If ``style`` is not a key of ``style2id``.
        """
        logger.info(f"Start generating audio data from text:\n{text}")
        if language != "JP" and self.hps.version.endswith("JP-Extra"):
            raise ValueError(
                "The model is trained with JP-Extra, but the language is not JP"
            )
        # Empty strings are normalized to "not provided".
        if reference_audio_path == "":
            reference_audio_path = None
        if assist_text == "" or not use_assist_text:
            assist_text = None

        if self.net_g is None:
            self.load_net_g()
        if reference_audio_path is None:
            style_vector = self.get_style_vector(self.style2id[style], style_weight)
        else:
            style_vector = self.get_style_vector_from_audio(
                reference_audio_path, style_weight
            )

        sampling_rate = self.hps.data.sampling_rate
        # Arguments shared by the single-shot and the per-line synthesis.
        common_kwargs = dict(
            sdp_ratio=sdp_ratio,
            noise_scale=noise,
            noise_scale_w=noisew,
            length_scale=length,
            sid=sid,
            language=language,
            hps=self.hps,
            net_g=self.net_g,
            device=self.device,
            assist_text=assist_text,
            assist_text_weight=assist_text_weight,
            style_vec=style_vector,
            ignore_unknown=ignore_unknown,
        )

        if not line_split:
            with torch.no_grad():
                audio = infer(text=text, given_tone=given_tone, **common_kwargs)
        else:
            lines = [line for line in text.split("\n") if line != ""]
            if not lines:
                # np.concatenate([]) would raise a less helpful ValueError.
                raise ValueError("No text to synthesize: every line is empty")
            # The gap must be sized with the model's own sampling rate so that
            # `split_interval` is really in seconds for the returned audio.
            silence = np.zeros(int(sampling_rate * split_interval))
            audios = []
            with torch.no_grad():
                for i, line in enumerate(lines):
                    if i > 0:
                        audios.append(silence)
                    # given_tone is intentionally not forwarded per line.
                    audios.append(infer(text=line, **common_kwargs))
            audio = np.concatenate(audios)
        logger.info("Audio data generated successfully")

        if not (pitch_scale == 1.0 and intonation_scale == 1.0):
            _, audio = adjust_voice(
                fs=sampling_rate,
                wave=audio,
                pitch_scale=pitch_scale,
                intonation_scale=intonation_scale,
            )
        with warnings.catch_warnings():
            # Silence any warning emitted during the 16-bit conversion.
            warnings.simplefilter("ignore")
            audio = convert_to_16_bit_wav(audio)
        return (sampling_rate, audio)
|
|
|
|
|
class ModelHolder:
    """Discovers models under a root directory and keeps one of them loaded.

    Every sub-directory of ``root_dir`` that contains at least one weight
    file (``.pth``, ``.pt`` or ``.safetensors``) and a ``config.json`` is
    treated as a model.
    """

    # File suffixes recognized as model weight files.
    _MODEL_FILE_SUFFIXES = (".pth", ".pt", ".safetensors")

    def __init__(self, root_dir: Path, device: str):
        """Scan ``root_dir`` for models.

        Args:
            root_dir: Directory whose sub-directories each hold one model.
            device: Device string handed to every ``Model`` created here.
        """
        self.root_dir: Path = root_dir
        self.device: str = device
        self.model_files_dict: dict[str, list[Path]] = {}
        self.current_model: Optional[Model] = None
        self.model_names: list[str] = []
        self.models: list[Model] = []
        self.refresh()

    def refresh(self):
        """Rescan ``root_dir`` and forget the loaded model and cached info."""
        self.model_files_dict = {}
        self.model_names = []
        self.current_model = None
        # Drop the models_info() cache so it is rebuilt from the new scan.
        self._models_info = None

        model_dirs = [d for d in self.root_dir.iterdir() if d.is_dir()]
        for model_dir in model_dirs:
            model_files = [
                f
                for f in model_dir.iterdir()
                if f.suffix in self._MODEL_FILE_SUFFIXES
            ]
            if len(model_files) == 0:
                logger.warning(f"No model files found in {model_dir}, so skip it")
                continue
            config_path = model_dir / "config.json"
            if not config_path.exists():
                logger.warning(
                    f"Config file {config_path} not found, so skip {model_dir}"
                )
                continue
            self.model_files_dict[model_dir.name] = model_files
            self.model_names.append(model_dir.name)

    def models_info(self):
        """Return name, weight files and style names for every known model.

        The result is cached until the next ``refresh()``.

        Returns:
            A list of dicts with keys ``"name"``, ``"files"`` (paths as
            strings) and ``"styles"`` (style names).
        """
        cached = getattr(self, "_models_info", None)
        if cached is not None:
            return cached
        result = []
        for name, files in self.model_files_dict.items():
            config_path = self.root_dir / name / "config.json"
            hps = utils.get_hparams_from_file(config_path)
            if hasattr(hps.data, "style2id"):
                styles = list(hps.data.style2id.keys())
            else:
                # Same fallback as Model.__init__: unnamed styles are "0", "1", ...
                styles = [str(i) for i in range(hps.data.num_styles)]
            result.append(
                {
                    "name": name,
                    "files": [str(f) for f in files],
                    "styles": styles,
                }
            )
        self._models_info = result
        return result

    def load_model(self, model_name: str, model_path_str: str):
        """Make the given weight file the current model and return it.

        A new ``Model`` is only constructed when no model is loaded or the
        loaded one uses a different weight file.

        Args:
            model_name: Name of a discovered model directory.
            model_path_str: Path of one of that model's weight files.

        Raises:
            ValueError: If the model name or the weight file is unknown.
        """
        model_path = Path(model_path_str)
        if model_name not in self.model_files_dict:
            raise ValueError(f"Model `{model_name}` is not found")
        if model_path not in self.model_files_dict[model_name]:
            raise ValueError(f"Model file `{model_path}` is not found")
        if self.current_model is None or self.current_model.model_path != model_path:
            model_dir = self.root_dir / model_name
            self.current_model = Model(
                model_path=model_path,
                config_path=model_dir / "config.json",
                style_vec_path=model_dir / "style_vectors.npy",
                device=self.device,
            )
        return self.current_model

    def load_model_gr(
        self, model_name: str, model_path_str: str
    ) -> "tuple[gr.Dropdown, gr.Button, gr.Dropdown]":
        """Load a model and return updated Gradio components for it.

        Returns:
            A style dropdown, an enabled synthesis button and a speaker
            dropdown, each preselecting its first choice.

        Raises:
            ValueError: If the model name or the weight file is unknown.
        """
        # load_model() validates the arguments and reuses the current model
        # when the same weight file is already loaded.
        model = self.load_model(model_name, model_path_str)
        speakers = list(model.spk2id.keys())
        styles = list(model.style2id.keys())
        return (
            gr.Dropdown(choices=styles, value=styles[0]),
            # Button label is Japanese for "speech synthesis".
            gr.Button(interactive=True, value="音声合成"),
            gr.Dropdown(choices=speakers, value=speakers[0]),
        )

    def update_model_files_gr(self, model_name: str) -> "gr.Dropdown":
        """Return a dropdown listing the weight files of ``model_name``.

        Raises:
            KeyError: If ``model_name`` is not a discovered model.
        """
        model_files = self.model_files_dict[model_name]
        return gr.Dropdown(choices=model_files, value=model_files[0])

    def update_model_names_gr(
        self,
    ) -> "tuple[gr.Dropdown, gr.Dropdown, gr.Button]":
        """Rescan the root directory and return refreshed Gradio components.

        Returns:
            A model-name dropdown, a weight-file dropdown for the first
            model, and a disabled button (no model is loaded after a rescan).

        Raises:
            IndexError: If no model is found under ``root_dir``.
        """
        self.refresh()
        initial_model_name = self.model_names[0]
        initial_model_files = self.model_files_dict[initial_model_name]
        return (
            gr.Dropdown(choices=self.model_names, value=initial_model_name),
            gr.Dropdown(choices=initial_model_files, value=initial_model_files[0]),
            gr.Button(interactive=False),
        )
|
|