import os
import logging
import lzma
from dataclasses import dataclass
from typing import Optional

import numpy as np
import pybase16384 as b14
import torch
from huggingface_hub import snapshot_download
from omegaconf import OmegaConf
from vocos import Vocos

from .model.dvae import DVAE
from .model.gpt import GPT_warpper
from .utils.gpu_utils import select_device
from .utils.io_utils import get_latest_modified_file
from .infer.api import refine_text, infer_code

logging.basicConfig(level=logging.INFO)


class Chat:
    def __init__(self):
        self.pretrain_models = {}
        self.logger = logging.getLogger(__name__)
        self.gpt: Optional[GPT_warpper] = None

    def check_model(self, level=logging.INFO, use_decoder=False):
        """Return True iff every module needed for inference is loaded."""
        not_finish = False
        check_list = ['vocos', 'gpt', 'tokenizer']

        # use_decoder selects mel synthesis from GPT hidden states instead of
        # DVAE decoding of the discrete codes.
        if use_decoder:
            check_list.append('decoder')
        else:
            check_list.append('dvae')

        for module in check_list:
            if module not in self.pretrain_models:
                self.logger.log(logging.WARNING, f'{module} not initialized.')
                not_finish = True

        if not not_finish:
            self.logger.log(level, 'All initialized.')

        return not not_finish

    def load_models(self, source='huggingface', force_redownload=False, local_path=''):
        if source == 'huggingface':
            hf_home = os.getenv('HF_HOME', os.path.expanduser("~/.cache/huggingface"))
            try:
                download_path = get_latest_modified_file(
                    os.path.join(hf_home, 'hub/models--2Noise--ChatTTS/snapshots'))
            except Exception:
                download_path = None
            if download_path is None or force_redownload:
                self.logger.log(logging.INFO, 'Download from HF: https://huggingface.co/2Noise/ChatTTS')
                download_path = snapshot_download(repo_id="2Noise/ChatTTS", allow_patterns=["*.pt", "*.yaml"])
            else:
                self.logger.log(logging.INFO, f'Load from cache: {download_path}')
            self._load(**{
                k: os.path.join(download_path, v)
                for k, v in OmegaConf.load(os.path.join(download_path, 'config', 'path.yaml')).items()
            })
        elif source == 'local':
            self.logger.log(logging.INFO, f'Load from local: {local_path}')
            self._load(**{
                k: os.path.join(local_path, v)
                for k, v in OmegaConf.load(os.path.join(local_path, 'config', 'path.yaml')).items()
            })

    def _load(
        self,
        vocos_config_path: Optional[str] = None,
        vocos_ckpt_path: Optional[str] = None,
        dvae_config_path: Optional[str] = None,
        dvae_ckpt_path: Optional[str] = None,
        gpt_config_path: Optional[str] = None,
        gpt_ckpt_path: Optional[str] = None,
        decoder_config_path: Optional[str] = None,
        decoder_ckpt_path: Optional[str] = None,
        tokenizer_path: Optional[str] = None,
        device: Optional[str] = None,
    ):
        if not device:
            device = select_device(4096)
            self.logger.log(logging.INFO, f'use {device}')

        if vocos_config_path:
            vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
            assert vocos_ckpt_path, 'vocos_ckpt_path should not be None'
            vocos.load_state_dict(torch.load(vocos_ckpt_path))
            self.pretrain_models['vocos'] = vocos
            self.logger.log(logging.INFO, 'vocos loaded.')

        if dvae_config_path:
            cfg = OmegaConf.load(dvae_config_path)
            dvae = DVAE(**cfg).to(device).eval()
            assert dvae_ckpt_path, 'dvae_ckpt_path should not be None'
            dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location='cpu'))
            self.pretrain_models['dvae'] = dvae
            self.logger.log(logging.INFO, 'dvae loaded.')

        if gpt_config_path:
            cfg = OmegaConf.load(gpt_config_path)
            gpt = GPT_warpper(**cfg).to(device).eval()
            assert gpt_ckpt_path, 'gpt_ckpt_path should not be None'
            gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location='cpu'))
            self.pretrain_models['gpt'] = gpt
            self.gpt = gpt
            self.logger.log(logging.INFO, 'gpt loaded.')

            # Speaker statistics are expected to sit next to the GPT checkpoint.
            spk_stat_path = os.path.join(os.path.dirname(gpt_ckpt_path), "spk_stat.pt")
            assert os.path.exists(spk_stat_path), f"Missing spk_stat.pt: {spk_stat_path}"
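            # spk_stat.pt holds the concatenated per-dimension (std, mean)
            # halves that _sample_random_speaker() splits via .chunk(2) below.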
            self.pretrain_models["spk_stat"] = torch.load(
                spk_stat_path, weights_only=True, mmap=True, map_location='cpu'
            ).to(device)

        if decoder_config_path:
            cfg = OmegaConf.load(decoder_config_path)
            decoder = DVAE(**cfg).to(device).eval()
            assert decoder_ckpt_path, 'decoder_ckpt_path should not be None'
            decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location='cpu'))
            self.pretrain_models['decoder'] = decoder
            self.logger.log(logging.INFO, 'decoder loaded.')

        if tokenizer_path:
            tokenizer = torch.load(tokenizer_path, map_location='cpu')
            tokenizer.padding_side = 'left'
            self.pretrain_models['tokenizer'] = tokenizer
            self.logger.log(logging.INFO, 'tokenizer loaded.')

        self.check_model()

    @dataclass(repr=False, eq=False)
    class RefineTextParams:
        prompt: str = ""
        top_P: float = 0.7
        top_K: int = 20
        temperature: float = 0.7
        repetition_penalty: float = 1.0
        max_new_token: int = 384
        min_new_token: int = 0
        show_tqdm: bool = True
        ensure_non_empty: bool = True

    @dataclass(repr=False, eq=False)
    class InferCodeParams(RefineTextParams):
        prompt: str = "[speed_5]"
        spk_emb: Optional[str] = None
        temperature: float = 0.3
        repetition_penalty: float = 1.05
        max_new_token: int = 2048

    def infer(
        self,
        text,
        skip_refine_text=False,
        refine_text_only=False,
        params_refine_text=None,
        params_infer_code=None,
        use_decoder=False,
    ):
        # Avoid mutable default arguments; copy params_infer_code because
        # 'prompt' is popped from it below.
        params_refine_text = params_refine_text or {}
        params_infer_code = dict(params_infer_code or {})

        assert self.check_model(use_decoder=use_decoder)

        if not skip_refine_text:
            text_tokens = refine_text(self.pretrain_models, text, **params_refine_text)['ids']
            # Strip control tokens: ids at or above [break_0] are special tokens.
            text_tokens = [
                i[i < self.pretrain_models['tokenizer'].convert_tokens_to_ids('[break_0]')]
                for i in text_tokens
            ]
            text = self.pretrain_models['tokenizer'].batch_decode(text_tokens)
            if refine_text_only:
                return text

        # Prepend the style prompt to each input, then drop it from the kwargs
        # forwarded to infer_code.
        text = [params_infer_code.get('prompt', '') + i for i in text]
        params_infer_code.pop('prompt', '')
        result = infer_code(self.pretrain_models, text, **params_infer_code, return_hidden=use_decoder)

        if use_decoder:
            mel_spec = [self.pretrain_models['decoder'](i[None].permute(0, 2, 1)) for i in result['hiddens']]
        else:
            mel_spec = [self.pretrain_models['dvae'](i[None].permute(0, 2, 1)) for i in result['ids']]

        wav = [self.pretrain_models['vocos'].decode(i).cpu().numpy() for i in mel_spec]

        return wav

    def sample_random_speaker(self) -> str:
        return self._encode_spk_emb(self._sample_random_speaker())

    @staticmethod
    def _encode_spk_emb(spk_emb: torch.Tensor) -> str:
        # Serialize the float16 embedding as raw LZMA-compressed base16384 text.
        with torch.no_grad():
            arr: np.ndarray = spk_emb.to(dtype=torch.float16, device="cpu").numpy()
            s = b14.encode_to_string(
                lzma.compress(
                    arr.tobytes(),
                    format=lzma.FORMAT_RAW,
                    filters=[
                        {"id": lzma.FILTER_LZMA2, "preset": 9 | lzma.PRESET_EXTREME}
                    ],
                ),
            )
            del arr
        return s

    def _sample_random_speaker(self) -> torch.Tensor:
        with torch.no_grad():
            # The embedding dimension equals the GPT backbone's hidden size.
            dim: int = self.gpt.gpt.layers[0].mlp.gate_proj.in_features
            out: torch.Tensor = self.pretrain_models["spk_stat"]
            std, mean = out.chunk(2)
            # Draw one sample from N(mean, std^2) per dimension in float16.
            spk = (
                torch.randn(dim, device=std.device, dtype=torch.float16)
                .mul_(std)
                .add_(mean)
            )
            del out, std, mean
            return spk
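

# Minimal usage sketch, assuming this module lives at ChatTTS/core.py inside a
# package named `ChatTTS` (so it can be run as `python -m ChatTTS.core`) and
# that `soundfile` is installed; both are assumptions for illustration, not
# requirements of the module itself.
if __name__ == "__main__":
    import soundfile as sf  # assumed extra dependency, used only for this demo

    chat = Chat()
    chat.load_models()  # fetches weights from Hugging Face on first run

    texts = ["Hello, this is a quick ChatTTS demo."]

    # refine_text_only=True returns the refined transcripts instead of audio.
    print(chat.infer(texts, refine_text_only=True))

    # Full pipeline: text -> GPT codes -> (DVAE or decoder) mel -> Vocos waveform.
    wavs = chat.infer(texts)  # list of arrays shaped (1, n_samples)
    sf.write("output.wav", wavs[0][0], 24000)  # ChatTTS vocoder output is 24 kHz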