File size: 1,838 Bytes
8f52106
01e655b
02e90e4
01e655b
02e90e4
01e655b
 
8f52106
01e655b
 
8f52106
01e655b
8a3a4ec
01e655b
 
8f52106
01e655b
 
8f52106
02e90e4
650b56c
01e655b
 
02e90e4
01e655b
 
02e90e4
 
 
 
 
 
01e655b
 
02e90e4
650b56c
8f52106
 
 
8a3a4ec
 
 
 
 
01e655b
8f52106
 
 
8a3a4ec
374f426
 
01e655b
02e90e4
 
8f52106
 
02e90e4
8f52106
02e90e4
8f52106
 
 
 
02e90e4
 
8f52106
02e90e4
8f52106
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
import threading
import torch
from modules.ChatTTS import ChatTTS
from modules import config
from modules.devices import devices

import logging
import gc

logger = logging.getLogger(__name__)

# Process-wide ChatTTS instance; None while no instance is cached.
chat_tts: "ChatTTS.Chat | None" = None
# Serialises model initialisation performed by initialize_chat_tts().
lock: threading.Lock = threading.Lock()


def load_chat_tts_in_thread():
    """Load the ChatTTS models into the module-level ``chat_tts`` cache.

    Does nothing if an instance is already cached. The new ``ChatTTS.Chat``
    is only published to the global after ``load_models`` has returned, so
    if ``load_models`` raises, ``chat_tts`` stays ``None`` and other threads
    never observe a partially initialised instance.

    ``initialize_chat_tts`` runs this on a worker thread while holding
    ``lock``.
    """
    global chat_tts
    if chat_tts is not None:
        return

    logger.info("Loading ChatTTS models")
    # Build into a local first: load_chat_tts() tests the global for None
    # without taking the lock, so assigning before loading would let callers
    # (or a failed load) leave an unusable instance in the cache.
    instance = ChatTTS.Chat()
    instance.load_models(
        compile=config.runtime_env_vars.compile,
        source="local",
        local_path="./models/ChatTTS",
        device=devices.device,
        dtype=devices.dtype,
        dtype_vocos=devices.dtype_vocos,
        dtype_dvae=devices.dtype_dvae,
        dtype_gpt=devices.dtype_gpt,
        dtype_decoder=devices.dtype_decoder,
    )
    chat_tts = instance

    devices.torch_gc()
    logger.info("ChatTTS models loaded")


def initialize_chat_tts():
    """Ensure the ChatTTS models are loaded, loading them at most once.

    The ``None`` check is repeated under ``lock`` so concurrent callers do
    not start duplicate loads. The actual loading runs on a separate thread,
    and this call blocks until that thread finishes.
    """
    with lock:
        if chat_tts is not None:
            return
        loader = threading.Thread(target=load_chat_tts_in_thread)
        loader.start()
        loader.join()


def load_chat_tts():
    """Return the shared ChatTTS instance, loading it on first use.

    The unlocked ``None`` check is only a fast path; ``initialize_chat_tts``
    repeats the check under ``lock`` before it loads anything.

    Returns:
        The cached ``ChatTTS.Chat`` instance.

    Raises:
        RuntimeError: if no instance is cached after loading was attempted.
            (``RuntimeError`` subclasses ``Exception``, so existing
            ``except Exception`` handlers keep working.)
    """
    if chat_tts is None:
        initialize_chat_tts()
    if chat_tts is None:
        raise RuntimeError("Failed to load ChatTTS models")
    return chat_tts


def unload_chat_tts():
    """Drop the cached ChatTTS instance and release the memory it held.

    Each ``torch.nn.Module`` in ``chat_tts.pretrain_models`` is moved to the
    CPU, the global reference is cleared, the garbage collector is run, and
    only then is the CUDA caching allocator asked to release unused blocks.
    That ordering lets the allocator reclaim memory freed by the collection.
    Safe to call when nothing is loaded.
    """
    global chat_tts
    logger.info("Unloading ChatTTS models")

    # Take the same lock initialize_chat_tts() uses so an unload cannot
    # interleave with an in-progress load.
    with lock:
        instance, chat_tts = chat_tts, None
        if instance is not None:
            for model in instance.pretrain_models.values():
                if isinstance(model, torch.nn.Module):
                    model.cpu()
        # Drop the last local reference so gc.collect() can reclaim it.
        del instance

    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()
    logger.info("ChatTTS models unloaded")


def reload_chat_tts():
    """Discard the cached ChatTTS instance and load a fresh one.

    Returns:
        The newly loaded instance, as returned by ``load_chat_tts``.

    Raises:
        Whatever ``load_chat_tts`` raises when the models cannot be loaded.
    """
    # Use the module logger (not the root ``logging`` module) like the rest
    # of this file.
    logger.info("Reloading ChatTTS models")
    unload_chat_tts()
    instance = load_chat_tts()
    logger.info("ChatTTS models reloaded")
    return instance