# Lazy, process-wide loader for the ChatTTS model: load_chat_tts() builds one
# ChatTTS.Chat instance on first use and returns that same instance afterwards.

import logging

import torch

from modules import config
from modules.ChatTTS import ChatTTS

logger = logging.getLogger(__name__)

# Prefer the GPU when torch can see one; otherwise run on the CPU.
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
print(f"device use {device}")

# Cached Chat instance; stays None until load_chat_tts() completes successfully.
chat_tts = None


def _convert_to_half(chat) -> None:
    """Cast every ``torch.nn.Module`` in ``chat.pretrain_models`` to float16 in place.

    Each module is moved to the CPU, cast with ``.half()``, moved back to CUDA
    when CUDA is available, and switched to eval mode. Non-module entries in
    ``pretrain_models`` are left untouched.
    """
    logger.info("half precision enabled")
    for model_name, model in chat.pretrain_models.items():
        if not isinstance(model, torch.nn.Module):
            continue
        # Cast on the CPU after releasing cached CUDA blocks -- presumably to
        # keep peak GPU memory down during the conversion (NOTE(review): confirm).
        model.cpu()
        if torch.cuda.is_available():
            torch.cuda.empty_cache()
        model.half()
        if torch.cuda.is_available():
            model.cuda()
        model.eval()
        logger.info("%s converted to half precision.", model_name)


def load_chat_tts():
    """Return the shared ChatTTS ``Chat`` instance, loading it on the first call.

    Models are loaded from ``./models/ChatTTS`` (``source="local"``) onto the
    module-level ``device``, with ``config.enable_model_compile`` forwarded as
    ``compile``. When ``config.model_config["half"]`` is truthy, every
    ``torch.nn.Module`` in ``pretrain_models`` is converted to half precision.

    The instance is cached only after loading (and any conversion) has
    finished. If loading raises, the exception propagates and the next call
    retries, rather than returning a half-initialised object.

    Returns:
        The loaded ``ChatTTS.Chat`` instance (the same object on every call).
    """
    global chat_tts
    if chat_tts is not None:
        return chat_tts

    # Build into a local so a failure below never leaves a partially loaded
    # object in the module-level cache.
    chat = ChatTTS.Chat()
    chat.load_models(
        compile=config.enable_model_compile,
        source="local",
        local_path="./models/ChatTTS",
        device=device,
    )

    if config.model_config.get("half", False):
        _convert_to_half(chat)

    chat_tts = chat
    return chat_tts