import os |
import logging |
from omegaconf import OmegaConf |
import torch |
from vocos import Vocos |
from .model.dvae import DVAE |
from .model.gpt import GPT_warpper |
from .utils.gpu_utils import select_device |
from .utils.io_utils import get_latest_modified_file |
from .infer.api import refine_text, infer_code |
from dataclasses import dataclass |
from typing import Literal, Optional, List, Tuple, Dict, Union |
import numpy as np |
from tool.logger import get_logger |
from tool.normalizer import normalizer_en_nemo_text, normalizer_cn_tn |
from tool.func import encode_prompt |
from ChatTTS.norm import Normalizer |
from huggingface_hub import snapshot_download |
class Chat:
    """Load the ChatTTS sub-models and run text-to-speech inference with them.

    Loaded components are stored in ``self.pretrain_models`` under the keys
    ``'vocos'``, ``'dvae'``, ``'gpt'``, ``'decoder'``, ``'tokenizer'`` and
    ``'spk_stat'`` (each one only when the corresponding path was supplied
    to :meth:`_load`).
    """

    def __init__(self):
        """Create an empty model registry, a logger and the text normalizer."""
        self.pretrain_models = {}
        self.logger = get_logger(__name__, lv=logging.INFO)
        self.normalizer = Normalizer(
            os.path.join(os.path.dirname(__file__), "res", "homophones_map.json"),
            self.logger,
        )

    def check_model(self, level=logging.INFO, use_decoder=False):
        """Report whether every component needed for inference is loaded.

        Args:
            level: Logging level used for the "All initialized." message.
            use_decoder: When true the ``'decoder'`` component is required,
                otherwise the ``'dvae'`` component is required.

        Returns:
            ``True`` when all required components are present in
            ``self.pretrain_models``; ``False`` otherwise. Every missing
            component is logged at WARNING level.
        """
        required = ['vocos', 'gpt', 'tokenizer']
        required.append('decoder' if use_decoder else 'dvae')

        missing = [name for name in required if name not in self.pretrain_models]
        for name in missing:
            self.logger.log(logging.WARNING, f'{name} not initialized.')
        if missing:
            return False

        self.logger.log(level, 'All initialized.')
        return True

    @staticmethod
    def _resolve_model_paths(root):
        """Read ``<root>/config/path.yaml`` and join each entry onto ``root``.

        Returns:
            A dict mapping each key of ``path.yaml`` to ``os.path.join(root, value)``;
            it is passed to :meth:`_load` as keyword arguments.
        """
        path_cfg = OmegaConf.load(os.path.join(root, 'config', 'path.yaml'))
        return {key: os.path.join(root, rel) for key, rel in path_cfg.items()}

    def load_models(self, source='huggingface', force_redownload=False, local_path='<LOCAL_PATH>'):
        """Load all models either from the Hugging Face cache/hub or a local folder.

        Args:
            source: ``'huggingface'`` or ``'local'``. Any other value loads
                nothing and only logs a warning.
            force_redownload: For ``'huggingface'``, download a snapshot even
                when a cached one was found.
            local_path: Root directory used when ``source == 'local'``; it must
                contain ``config/path.yaml``.
        """
        if source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            try:
                download_path = get_latest_modified_file(
                    os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots')
                )
            except Exception:
                # Best effort: any failure while probing the cache simply means
                # "no cached snapshot". KeyboardInterrupt/SystemExit propagate.
                download_path = None

            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')

            self._load(**self._resolve_model_paths(download_path))
            self._regist_normalizer()
        elif source == 'local':
            self.logger.log(logging.INFO, f'Load from local: {local_path}')
            self._load(**self._resolve_model_paths(local_path))
            # NOTE(review): unlike the 'huggingface' branch, no normalizers are
            # registered here. Kept as in the original; confirm it is intended.
        else:
            # Previously a silent no-op; make the misconfiguration visible.
            self.logger.log(logging.WARNING, f'Unknown model source {source!r}; nothing was loaded.')

    def _try_register_normalizer(self, lang, factory, factory_name, package):
        """Register ``factory()`` for ``lang`` on the normalizer, logging failures.

        A ``ValueError`` is reported as a registration failure; any other
        exception is reported as the backing ``package`` being unavailable,
        together with the install command. Nothing is re-raised.
        """
        try:
            self.normalizer.register(lang, factory())
        except ValueError as e:
            # The exception needs a %s placeholder; passing it as a bare extra
            # argument makes logging fail while formatting the record.
            self.logger.error('%s register fail: %s', factory_name, e)
        except Exception:
            self.logger.error(f"Package {package} not found!")
            self.logger.error(
                f"Run: conda install -c conda-forge pynini=2.1.5 && pip install {package}",
            )

    def _regist_normalizer(self):
        """Register the English and Chinese text normalizers (best effort)."""
        self.logger.info("==========开始注册 normalizer===========")
        self._try_register_normalizer(
            "en", normalizer_en_nemo_text, "normalizer_en_nemo_text", "nemo_text_processing"
        )
        self._try_register_normalizer(
            "zh", normalizer_cn_tn, "normalizer_cn_tn", "WeTextProcessing"
        )

    @staticmethod
    def _load_state_into(module, ckpt_path, arg_name):
        """Load the state dict stored at ``ckpt_path`` into ``module``.

        Raises:
            AssertionError: If ``ckpt_path`` is empty/``None``. Raised explicitly
                so the check also runs under ``python -O``.
        """
        if not ckpt_path:
            raise AssertionError(f'{arg_name} should not be None')
        # NOTE(security): torch.load unpickles the file; only load checkpoints
        # that come from a trusted source.
        module.load_state_dict(torch.load(ckpt_path))
        return module

    def _load(
        self,
        vocos_config_path: str = None,
        vocos_ckpt_path: str = None,
        dvae_config_path: str = None,
        dvae_ckpt_path: str = None,
        gpt_config_path: str = None,
        gpt_ckpt_path: str = None,
        decoder_config_path: str = None,
        decoder_ckpt_path: str = None,
        tokenizer_path: str = None,
        device: str = None
    ):
        """Instantiate each component whose config path is given and load its weights.

        A component is skipped when its ``*_config_path`` (or ``tokenizer_path``)
        is falsy. When a config path is given, the matching ``*_ckpt_path`` is
        mandatory. Loading the GPT additionally requires a ``spk_stat.pt`` file
        next to the GPT checkpoint.

        Args:
            device: Target device; when falsy it is chosen by
                ``select_device(4096)``.

        Raises:
            AssertionError: If a required checkpoint path is missing or
                ``spk_stat.pt`` does not exist.
        """
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')
        self.device = device

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            self._load_state_into(vocos, vocos_ckpt_path, 'vocos_ckpt_path')
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            dvae_cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**dvae_cfg).to(device).eval()
            self._load_state_into(dvae, dvae_ckpt_path, 'dvae_ckpt_path')
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            gpt_cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**gpt_cfg).to(device).eval()
            self._load_state_into(gpt, gpt_ckpt_path, 'gpt_ckpt_path')
            self.pretrain_models['gpt'] = gpt
            # Also kept as an attribute: sample_random_speaker_tensor reads it.
            self.gpt = gpt
            self.logger.log(logging.INFO, 'gpt loaded.')

            # Speaker statistics are expected alongside the GPT checkpoint.
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
            if not os.path.exists(spk_stat_path):
                raise AssertionError(f"Missing spk_stat.pt: {spk_stat_path}")
            self.pretrain_models["spk_stat"] = torch.load(
                spk_stat_path, weights_only=True, mmap=True
            ).to(device)

        if decoder_config_path:
            decoder_cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**decoder_cfg).to(device).eval()
            self._load_state_into(decoder, decoder_ckpt_path, 'decoder_ckpt_path')
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            # NOTE(security): the tokenizer is a pickled object loaded through
            # torch.load; only use files from a trusted source.
            tokenizer = torch.load(tokenizer_path)
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    @dataclass(repr=False, eq=False)
    class RefineTextParams:
        """Default values for the text-refinement parameters.

        NOTE(review): in the visible code ``infer`` takes plain dicts and does
        not consume this class; confirm how callers are expected to use it.
        """
        prompt: str = ""
        top_P: float = 0.7
        top_K: int = 20
        temperature: float = 0.7
        repetition_penalty: float = 1.0
        max_new_token: int = 384
        min_new_token: int = 0
        show_tqdm: bool = True
        ensure_non_empty: bool = True

    @dataclass(repr=False, eq=False)
    class InferCodeParams(RefineTextParams):
        """Default values for the code-inference parameters.

        Inherits every field of ``RefineTextParams`` and overrides ``prompt``,
        ``temperature``, ``repetition_penalty`` and ``max_new_token``; adds
        ``spk_emb``.
        """
        prompt: str = "[speed_5]"
        spk_emb: Optional[str] = None
        temperature: float = 0.3
        repetition_penalty: float = 1.05
        max_new_token: int = 2048

    def infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=False,
        lang=None
    ):
        """Normalize, optionally refine, and synthesize ``text``.

        Args:
            text: A single string or a list of strings.
            skip_refine_text: Skip the model-based text refinement step; only
                the rule-based normalizer is applied.
            refine_text_only: Return the refined text instead of audio.
            params_refine_text: Extra keyword arguments for ``refine_text``.
                ``None`` means no extra arguments.
            params_infer_code: Extra keyword arguments for ``infer_code``. An
                optional ``'prompt'`` entry is prepended to every text instead
                of being forwarded. The dict passed in is never modified.
            use_decoder: Use the ``'decoder'`` component on hidden states
                instead of the ``'dvae'`` component on ids.
            lang: Language forwarded to the normalizer.

        Returns:
            A list of refined strings when ``refine_text_only`` takes effect,
            otherwise a list of NumPy arrays, one per input text.

        Raises:
            AssertionError: If the required models are not loaded.
        """
        self.logger.info(
            f"========开始infer模型,use_decoder:{use_decoder},lang:{lang},"
            f"mskip_refine_text:{skip_refine_text},refine_text_only:{refine_text_only}======")

        # Explicit raise (same exception type as the former assert) so the
        # check is not stripped when running with -O.
        if not self.check_model(use_decoder=use_decoder):
            raise AssertionError('required models are not initialized')

        # Work on private copies: 'prompt' is popped below and must not
        # disappear from the caller's dict (or from a shared default).
        refine_kwargs = dict(params_refine_text) if params_refine_text else {}
        infer_kwargs = dict(params_infer_code) if params_infer_code else {}

        if not isinstance(text, list):
            text = [text]

        text = [
            self.normalizer(
                text=t,
                do_text_normalization=True,
                do_homophone_replacement=True,
                lang=lang,
            )
            for t in text
        ]

        if skip_refine_text:
            self.logger.info("========对文本内容不做优化处理,仅做规则处理======")
        else:
            self.logger.info(f"========针对文本内容做模型优化处理,lang:{lang}======")
            tokenizer = self.pretrain_models['tokenizer']
            # Loop-invariant: look the boundary token id up once.
            break_0_id = tokenizer.convert_tokens_to_ids('[break_0]')
            text_tokens = refine_text(self.pretrain_models, text, **refine_kwargs)['ids']
            # Keep only ids below the '[break_0]' id before decoding.
            text_tokens = [ids[ids < break_0_id] for ids in text_tokens]
            text = tokenizer.batch_decode(text_tokens)
            # NOTE(review): the source's indentation was lost; this early
            # return is placed inside the refinement branch, i.e. it is
            # honoured only when refinement actually ran. Confirm.
            if refine_text_only:
                return text

        prompt = infer_kwargs.pop('prompt', '')
        text = [prompt + t for t in text]

        result = infer_code(self.pretrain_models, text, **infer_kwargs, return_hidden=use_decoder)

        if use_decoder:
            decoder = self.pretrain_models['decoder']
            mel_spec = [decoder(h[None].permute(0, 2, 1)) for h in result['hiddens']]
        else:
            dvae = self.pretrain_models['dvae']
            mel_spec = [dvae(ids[None].permute(0, 2, 1)) for ids in result['ids']]

        vocos = self.pretrain_models['vocos']
        wav = [vocos.decode(mel).cpu().numpy() for mel in mel_spec]
        return wav

    def emptpy_audio(self):
        """Run :meth:`infer` on a single-space string without text refinement.

        The misspelled method name is kept for backward compatibility.
        """
        return self.infer(" ",
                          skip_refine_text=True,
                          refine_text_only=False,
                          params_refine_text={},
                          params_infer_code={},
                          use_decoder=False)

    @torch.inference_mode()
    def sample_audio_speaker(self, wav: Union[np.ndarray, torch.Tensor]) -> str:
        """Transcode an audio tensor into an encoded speaker prompt string.

        A NumPy input is converted to a tensor on ``self.device``; a tensor
        input is used as given. The audio is run through the ``'dvae'``
        component in ``"encode"`` mode, the leading dimension is squeezed and
        the result is passed to ``encode_prompt``.
        """
        if isinstance(wav, np.ndarray):
            wav = torch.from_numpy(wav).to(self.device)
        encoded = self.pretrain_models['dvae'](wav, "encode").squeeze_(0)
        return encode_prompt(encoded)

    def sample_random_speaker_tensor(self) -> torch.Tensor:
        """Draw a random speaker embedding from the loaded speaker statistics.

        ``spk_stat`` is split into two halves used as ``std`` and ``mean``; the
        result is ``randn(dim) * std + mean`` in float16 on the statistics'
        device, where ``dim`` is the input width of the first GPT layer's
        ``gate_proj``.
        """
        with torch.no_grad():
            dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
            stats: torch.Tensor = self.pretrain_models["spk_stat"]
            std, mean = stats.chunk(2)
            spk = torch.randn(dim, device=std.device, dtype=torch.float16)
            spk.mul_(std).add_(mean)
        return spk