zhzluke96 commited on
Commit
02e90e4
1 Parent(s): d6fe286
modules/ChatTTS/ChatTTS/core.py CHANGED
@@ -101,13 +101,27 @@ class Chat:
101
  tokenizer_path: str = None,
102
  device: str = None,
103
  compile: bool = True,
 
 
 
 
 
104
  ):
105
  if not device:
106
  device = select_device(4096)
107
  self.logger.log(logging.INFO, f"use {device}")
108
 
 
 
 
 
 
109
  if vocos_config_path:
110
- vocos = Vocos.from_hparams(vocos_config_path).to(device).eval()
 
 
 
 
111
  assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
112
  vocos.load_state_dict(torch.load(vocos_ckpt_path))
113
  self.pretrain_models["vocos"] = vocos
@@ -115,7 +129,7 @@ class Chat:
115
 
116
  if dvae_config_path:
117
  cfg = OmegaConf.load(dvae_config_path)
118
- dvae = DVAE(**cfg).to(device).eval()
119
  assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
120
  dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
121
  self.pretrain_models["dvae"] = dvae
@@ -123,7 +137,7 @@ class Chat:
123
 
124
  if gpt_config_path:
125
  cfg = OmegaConf.load(gpt_config_path)
126
- gpt = GPT_warpper(**cfg).to(device).eval()
127
  assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
128
  gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
129
  if compile and "cuda" in str(device):
@@ -136,12 +150,14 @@ class Chat:
136
  assert os.path.exists(
137
  spk_stat_path
138
  ), f"Missing spk_stat.pt: {spk_stat_path}"
139
- self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(device)
 
 
140
  self.logger.log(logging.INFO, "gpt loaded.")
141
 
142
  if decoder_config_path:
143
  cfg = OmegaConf.load(decoder_config_path)
144
- decoder = DVAE(**cfg).to(device).eval()
145
  assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
146
  decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
147
  self.pretrain_models["decoder"] = decoder
 
101
  tokenizer_path: str = None,
102
  device: str = None,
103
  compile: bool = True,
104
+ dtype: torch.dtype = torch.float32,
105
+ dtype_vocos: torch.dtype = None,
106
+ dtype_dvae: torch.dtype = None,
107
+ dtype_gpt: torch.dtype = None,
108
+ dtype_decoder: torch.dtype = None,
109
  ):
110
  if not device:
111
  device = select_device(4096)
112
  self.logger.log(logging.INFO, f"use {device}")
113
 
114
+ dtype_vocos = dtype_vocos or dtype
115
+ dtype_dvae = dtype_dvae or dtype
116
+ dtype_gpt = dtype_gpt or dtype
117
+ dtype_decoder = dtype_decoder or dtype
118
+
119
  if vocos_config_path:
120
+ vocos = (
121
+ Vocos.from_hparams(vocos_config_path)
122
+ .to(device=device, dtype=dtype_vocos)
123
+ .eval()
124
+ )
125
  assert vocos_ckpt_path, "vocos_ckpt_path should not be None"
126
  vocos.load_state_dict(torch.load(vocos_ckpt_path))
127
  self.pretrain_models["vocos"] = vocos
 
129
 
130
  if dvae_config_path:
131
  cfg = OmegaConf.load(dvae_config_path)
132
+ dvae = DVAE(**cfg).to(device=device, dtype=dtype_dvae).eval()
133
  assert dvae_ckpt_path, "dvae_ckpt_path should not be None"
134
  dvae.load_state_dict(torch.load(dvae_ckpt_path, map_location=device))
135
  self.pretrain_models["dvae"] = dvae
 
137
 
138
  if gpt_config_path:
139
  cfg = OmegaConf.load(gpt_config_path)
140
+ gpt = GPT_warpper(**cfg).to(device=device, dtype=dtype_gpt).eval()
141
  assert gpt_ckpt_path, "gpt_ckpt_path should not be None"
142
  gpt.load_state_dict(torch.load(gpt_ckpt_path, map_location=device))
143
  if compile and "cuda" in str(device):
 
150
  assert os.path.exists(
151
  spk_stat_path
152
  ), f"Missing spk_stat.pt: {spk_stat_path}"
153
+ self.pretrain_models["spk_stat"] = torch.load(spk_stat_path).to(
154
+ device=device, dtype=dtype
155
+ )
156
  self.logger.log(logging.INFO, "gpt loaded.")
157
 
158
  if decoder_config_path:
159
  cfg = OmegaConf.load(decoder_config_path)
160
+ decoder = DVAE(**cfg).to(device=device, dtype=dtype_decoder).eval()
161
  assert decoder_ckpt_path, "decoder_ckpt_path should not be None"
162
  decoder.load_state_dict(torch.load(decoder_ckpt_path, map_location=device))
163
  self.pretrain_models["decoder"] = decoder
modules/ChatTTS/ChatTTS/model/dvae.py CHANGED
@@ -143,9 +143,9 @@ class DVAE(nn.Module):
143
  else:
144
  vq_feats = inp.detach().clone()
145
 
146
- temp = torch.chunk(vq_feats, 2, dim=1) # flatten trick :)
147
- temp = torch.stack(temp, -1)
148
- vq_feats = temp.reshape(*temp.shape[:2], -1)
149
 
150
  vq_feats = vq_feats.transpose(1, 2)
151
  dec_out = self.decoder(input=vq_feats)
 
143
  else:
144
  vq_feats = inp.detach().clone()
145
 
146
+ vq_feats = vq_feats.view(
147
+ (vq_feats.size(0), 2, vq_feats.size(1)//2, vq_feats.size(2)),
148
+ ).permute(0, 2, 3, 1).flatten(2)
149
 
150
  vq_feats = vq_feats.transpose(1, 2)
151
  dec_out = self.decoder(input=vq_feats)
modules/ChatTTS/ChatTTS/model/gpt.py CHANGED
@@ -190,6 +190,8 @@ class GPT_warpper(nn.Module):
190
  attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
191
 
192
  for i in tqdm(range(max_new_token)):
 
 
193
 
194
  model_input = self.prepare_inputs_for_generation(inputs_ids,
195
  outputs.past_key_values if i!=0 else None,
@@ -250,9 +252,6 @@ class GPT_warpper(nn.Module):
250
 
251
  end_idx = end_idx + (~finish).int()
252
 
253
- if finish.all():
254
- break
255
-
256
  inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
257
  inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
258
 
 
190
  attention_mask_cache[:, :attention_mask.shape[1]] = attention_mask
191
 
192
  for i in tqdm(range(max_new_token)):
193
+ if finish.all():
194
+ continue
195
 
196
  model_input = self.prepare_inputs_for_generation(inputs_ids,
197
  outputs.past_key_values if i!=0 else None,
 
252
 
253
  end_idx = end_idx + (~finish).int()
254
 
 
 
 
255
  inputs_ids = [inputs_ids[idx, start_idx: start_idx+i] for idx, i in enumerate(end_idx.int())]
256
  inputs_ids = [i[:, 0] for i in inputs_ids] if infer_text else inputs_ids
257
 
modules/ChatTTS/ChatTTS/utils/gpu_utils.py CHANGED
@@ -16,8 +16,10 @@ def select_device(min_memory = 2048):
16
  if free_memory_mb < min_memory:
17
  logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
18
  device = torch.device('cpu')
 
 
19
  else:
20
  logger.log(logging.WARNING, f'No GPU found, use CPU instead')
21
  device = torch.device('cpu')
22
 
23
- return device
 
16
  if free_memory_mb < min_memory:
17
  logger.log(logging.WARNING, f'GPU {selected_gpu} has {round(free_memory_mb, 2)} MB memory left.')
18
  device = torch.device('cpu')
19
+ elif torch.backends.mps.is_available() and torch.backends.mps.is_built():
20
+ device = torch.device('mps')
21
  else:
22
  logger.log(logging.WARNING, f'No GPU found, use CPU instead')
23
  device = torch.device('cpu')
24
 
25
+ return device
modules/ChatTTS/ChatTTS/utils/infer_utils.py CHANGED
@@ -101,8 +101,8 @@ character_map = {
101
  "!": ".",
102
  "(": ",",
103
  ")": ",",
104
- # '[': ',',
105
- # ']': ',',
106
  ">": ",",
107
  "<": ",",
108
  "-": ",",
@@ -131,11 +131,11 @@ halfwidth_2_fullwidth_map = {
131
  ">": ">",
132
  "?": "?",
133
  "@": "@",
134
- # '[': '',
135
  "\\": "\",
136
- # ']': '',
137
  "^": "^",
138
- # '_': '_',
139
  "`": "`",
140
  "{": "{",
141
  "|": "|",
 
101
  "!": ".",
102
  "(": ",",
103
  ")": ",",
104
+ "[": ",",
105
+ "]": ",",
106
  ">": ",",
107
  "<": ",",
108
  "-": ",",
 
131
  ">": ">",
132
  "?": "?",
133
  "@": "@",
134
+ "[": "",
135
  "\\": "\",
136
+ "]": "",
137
  "^": "^",
138
+ "_": "_",
139
  "`": "`",
140
  "{": "{",
141
  "|": "|",
modules/SynthesizeSegments.py CHANGED
@@ -1,6 +1,6 @@
1
  import numpy as np
2
  from pydub import AudioSegment
3
- from typing import Any, List, Dict
4
  from scipy.io.wavfile import write
5
  import io
6
  from modules.utils.audio import time_stretch, pitch_shift
@@ -211,7 +211,7 @@ def generate_audio_segment(
211
  return AudioSegment.from_file(byte_io, format="wav")
212
 
213
 
214
- def synthesize_segment(segment: Dict[str, Any]) -> AudioSegment | None:
215
  if "break" in segment:
216
  pause_segment = AudioSegment.silent(duration=segment["break"])
217
  return pause_segment
 
1
  import numpy as np
2
  from pydub import AudioSegment
3
+ from typing import Any, List, Dict, Union
4
  from scipy.io.wavfile import write
5
  import io
6
  from modules.utils.audio import time_stretch, pitch_shift
 
211
  return AudioSegment.from_file(byte_io, format="wav")
212
 
213
 
214
+ def synthesize_segment(segment: Dict[str, Any]) -> Union[AudioSegment, None]:
215
  if "break" in segment:
216
  pause_segment = AudioSegment.silent(duration=segment["break"])
217
  return pause_segment
modules/api/Api.py CHANGED
@@ -27,7 +27,18 @@ class APIManager:
27
  def __init__(self, no_docs=False, exclude_patterns=[]):
28
  self.app = FastAPI(
29
  title="ChatTTS Forge API",
30
- description="ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。\n\nChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax\n\n https://github.com/lenML/ChatTTS-Forge",
 
 
 
 
 
 
 
 
 
 
 
31
  version="0.1.0",
32
  redoc_url=None if no_docs else "/redoc",
33
  docs_url=None if no_docs else "/docs",
 
27
  def __init__(self, no_docs=False, exclude_patterns=[]):
28
  self.app = FastAPI(
29
  title="ChatTTS Forge API",
30
+ description="""
31
+ ChatTTS-Forge 是一个功能强大的文本转语音生成工具,支持通过类 SSML 语法生成丰富的音频长文本,并提供全面的 API 服务,适用于各种场景。<br/>
32
+ ChatTTS-Forge is a powerful text-to-speech generation tool that supports generating rich audio long texts through class SSML syntax
33
+
34
+ 项目地址: [https://github.com/lenML/ChatTTS-Forge](https://github.com/lenML/ChatTTS-Forge)
35
+
36
+ > 所有生成音频的 POST api都无法在此页面调试,调试建议使用 playground <br/>
37
+ > All audio generation POST APIs cannot be debugged on this page, it is recommended to use playground for debugging
38
+
39
+ > 如果你不熟悉本系统,建议从这个一键脚本开始,在colab中尝试一下:<br/>
40
+ > [https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb](https://colab.research.google.com/github/lenML/ChatTTS-Forge/blob/main/colab.ipynb)
41
+ """,
42
  version="0.1.0",
43
  redoc_url=None if no_docs else "/redoc",
44
  docs_url=None if no_docs else "/docs",
modules/api/impl/google_api.py CHANGED
@@ -30,6 +30,7 @@ class SynthesisInput(BaseModel):
30
 
31
  class VoiceSelectionParams(BaseModel):
32
  languageCode: str = "ZH-CN"
 
33
  name: str = "female2"
34
  style: str = ""
35
  temperature: float = 0.3
@@ -160,6 +161,18 @@ async def google_text_synthesize(request: GoogleTextSynthesizeRequest):
160
 
161
 
162
  def setup(app: APIManager):
163
- app.post("/v1/google/text:synthesize", response_model=GoogleTextSynthesizeResponse)(
164
- google_text_synthesize
165
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
30
 
31
  class VoiceSelectionParams(BaseModel):
32
  languageCode: str = "ZH-CN"
33
+
34
  name: str = "female2"
35
  style: str = ""
36
  temperature: float = 0.3
 
161
 
162
 
163
  def setup(app: APIManager):
164
+ app.post(
165
+ "/v1/text:synthesize",
166
+ response_model=GoogleTextSynthesizeResponse,
167
+ description="""
168
+ google api document: <br/>
169
+ [https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize](https://cloud.google.com/text-to-speech/docs/reference/rest/v1/text/synthesize)
170
+
171
+ - 多个属性在本系统中无用仅仅是为了兼容google api
172
+ - voice 中的 topP, topK, temperature 为本系统中的参数
173
+ - voice.name 即 speaker name (或者speaker seed)
174
+ - voice.seed 为 infer seed (可在webui中测试具体作用)
175
+
176
+ - 编码格式影响的是 audioContent 的二进制格式,所以所有format都是返回带有base64数据的json
177
+ """,
178
+ )(google_text_synthesize)
modules/api/impl/models_api.py ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from modules.api import utils as api_utils
2
+ from modules.api.Api import APIManager
3
+ from modules.models import reload_chat_tts
4
+
5
+
6
+ def setup(app: APIManager):
7
+ @app.get("/v1/models/reload", response_model=api_utils.BaseResponse)
8
+ async def reload_models():
9
+ # Reload models
10
+ reload_chat_tts()
11
+ return api_utils.success_response("Models reloaded")
modules/api/impl/openai_api.py CHANGED
@@ -28,11 +28,11 @@ class AudioSpeechRequest(BaseModel):
28
  model: str = "chattts-4w"
29
  voice: str = "female2"
30
  response_format: Literal["mp3", "wav"] = "mp3"
31
- speed: int = Field(1, ge=1, le=10, description="Speed of the audio")
32
  style: str = ""
33
  # 是否开启batch合成,小于等于1表示不适用batch
34
  # 开启batch合成会自动分割句子
35
- batch_size: int = Field(1, ge=1, le=10, description="Batch size")
36
  spliter_threshold: float = Field(
37
  100, ge=10, le=1024, description="Threshold for sentence spliter"
38
  )
@@ -64,8 +64,8 @@ async def openai_speech_api(
64
  params = api_utils.calc_spk_style(spk=voice, style=style)
65
 
66
  spk = params.get("spk", -1)
67
- seed = params.get("seed", 42)
68
- temperature = params.get("temperature", 0.3)
69
  prompt1 = params.get("prompt1", "")
70
  prompt2 = params.get("prompt2", "")
71
  prefix = params.get("prefix", "")
@@ -107,6 +107,18 @@ async def openai_speech_api(
107
 
108
 
109
  def setup(api_manager: APIManager):
110
- api_manager.post("/v1/openai/audio/speech", response_class=FileResponse)(
111
- openai_speech_api
112
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
28
  model: str = "chattts-4w"
29
  voice: str = "female2"
30
  response_format: Literal["mp3", "wav"] = "mp3"
31
+ speed: float = Field(1, ge=0.1, le=10, description="Speed of the audio")
32
  style: str = ""
33
  # 是否开启batch合成,小于等于1表示不适用batch
34
  # 开启batch合成会自动分割句子
35
+ batch_size: int = Field(1, ge=1, le=20, description="Batch size")
36
  spliter_threshold: float = Field(
37
  100, ge=10, le=1024, description="Threshold for sentence spliter"
38
  )
 
64
  params = api_utils.calc_spk_style(spk=voice, style=style)
65
 
66
  spk = params.get("spk", -1)
67
+ seed = params.get("seed", request.seed or 42)
68
+ temperature = params.get("temperature", request.temperature or 0.3)
69
  prompt1 = params.get("prompt1", "")
70
  prompt2 = params.get("prompt2", "")
71
  prefix = params.get("prefix", "")
 
107
 
108
 
109
  def setup(api_manager: APIManager):
110
+ api_manager.post(
111
+ "/v1/audio/speech",
112
+ response_class=FileResponse,
113
+ description="""
114
+ openai api document:
115
+ [https://platform.openai.com/docs/guides/text-to-speech](https://platform.openai.com/docs/guides/text-to-speech)
116
+
117
+ 以下属性为本系统自定义属性,不在openai文档中:
118
+ - batch_size: 是否开启batch合成,小于等于1表示不使用batch (不推荐)
119
+ - spliter_threshold: 开启batch合成时,句子分割的阈值
120
+ - style: 风格
121
+
122
+ > model 可填任意值
123
+ """,
124
+ )(openai_speech_api)
modules/api/impl/ping_api.py ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ from modules.api import utils as api_utils
2
+ from modules.api.Api import APIManager
3
+
4
+
5
+ def setup(app: APIManager):
6
+ @app.get("/v1/ping", response_model=api_utils.BaseResponse)
7
+ async def ping():
8
+ return {"message": "ok", "data": "pong"}
modules/api/utils.py CHANGED
@@ -1,5 +1,5 @@
1
  from pydantic import BaseModel
2
- from typing import Any
3
 
4
  import torch
5
 
@@ -36,6 +36,10 @@ class BaseResponse(BaseModel):
36
  }
37
 
38
 
 
 
 
 
39
  def wav_to_mp3(wav_data, bitrate="48k"):
40
  audio = AudioSegment.from_wav(
41
  wav_data,
@@ -51,7 +55,7 @@ def to_number(value, t, default=0):
51
  return default
52
 
53
 
54
- def calc_spk_style(spk: str | int, style: str | int):
55
  voice_attrs = {
56
  "spk": None,
57
  "seed": None,
 
1
  from pydantic import BaseModel
2
+ from typing import Any, Union
3
 
4
  import torch
5
 
 
36
  }
37
 
38
 
39
+ def success_response(data: Any, message: str = "Success") -> BaseResponse:
40
+ return BaseResponse(message=message, data=data)
41
+
42
+
43
  def wav_to_mp3(wav_data, bitrate="48k"):
44
  audio = AudioSegment.from_wav(
45
  wav_data,
 
55
  return default
56
 
57
 
58
+ def calc_spk_style(spk: Union[str, int], style: Union[str, int]):
59
  voice_attrs = {
60
  "spk": None,
61
  "seed": None,
modules/config.py CHANGED
@@ -1,11 +1,5 @@
1
- enable_model_compile = False
2
 
3
- lru_size = 64
4
-
5
- args = {}
6
 
7
  api = None
8
-
9
- model_config = {"half": False}
10
-
11
- disable_tqdm = False
 
1
+ from modules.utils.JsonObject import JsonObject
2
 
3
+ runtime_env_vars = JsonObject({})
 
 
4
 
5
  api = None
 
 
 
 
modules/devices/__init__.py ADDED
File without changes
modules/devices/devices.py ADDED
@@ -0,0 +1,160 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from functools import lru_cache
2
+ import sys
3
+ import torch
4
+ from modules import config
5
+
6
+ import logging
7
+
8
+ logger = logging.getLogger(__name__)
9
+
10
+ if sys.platform == "darwin":
11
+ from modules.devices import mac_devices
12
+
13
+
14
+ def has_mps() -> bool:
15
+ if sys.platform != "darwin":
16
+ return False
17
+ else:
18
+ return mac_devices.has_mps
19
+
20
+
21
+ def get_cuda_device_id():
22
+ return (
23
+ int(config.runtime_env_vars.device_id)
24
+ if config.runtime_env_vars.device_id is not None
25
+ and config.runtime_env_vars.device_id.isdigit()
26
+ else 0
27
+ ) or torch.cuda.current_device()
28
+
29
+
30
+ def get_cuda_device_string():
31
+ if config.runtime_env_vars.device_id is not None:
32
+ return f"cuda:{config.runtime_env_vars.device_id}"
33
+
34
+ return "cuda"
35
+
36
+
37
+ def get_available_gpus() -> list[tuple[int, int]]:
38
+ """
39
+ Get the list of available GPUs and their free memory.
40
+
41
+ :return: A list of tuples where each tuple contains (GPU index, free memory in bytes).
42
+ """
43
+ available_gpus = []
44
+ for i in range(torch.cuda.device_count()):
45
+ props = torch.cuda.get_device_properties(i)
46
+ free_memory = props.total_memory - torch.cuda.memory_reserved(i)
47
+ available_gpus.append((i, free_memory))
48
+ return available_gpus
49
+
50
+
51
+ def get_memory_available_gpus(min_memory=2048):
52
+ available_gpus = get_available_gpus()
53
+ memory_available_gpus = [
54
+ gpu for gpu, free_memory in available_gpus if free_memory > min_memory
55
+ ]
56
+ return memory_available_gpus
57
+
58
+
59
+ def get_target_device_id_or_memory_available_gpu():
60
+ memory_available_gpus = get_memory_available_gpus()
61
+ device_id = get_cuda_device_id()
62
+ if device_id not in memory_available_gpus:
63
+ if len(memory_available_gpus) != 0:
64
+ logger.warning(
65
+ f"Device {device_id} is not available or does not have enough memory. will try to use {memory_available_gpus}"
66
+ )
67
+ config.runtime_env_vars.device_id = str(memory_available_gpus[0])
68
+ else:
69
+ logger.warning(
70
+ f"Device {device_id} is not available or does not have enough memory. Using CPU instead."
71
+ )
72
+ return "cpu"
73
+ return get_cuda_device_string()
74
+
75
+
76
+ def get_optimal_device_name():
77
+ if config.runtime_env_vars.use_cpu:
78
+ return "cpu"
79
+
80
+ if torch.cuda.is_available():
81
+ return get_target_device_id_or_memory_available_gpu()
82
+
83
+ if has_mps():
84
+ return "mps"
85
+
86
+ return "cpu"
87
+
88
+
89
+ def get_optimal_device():
90
+ return torch.device(get_optimal_device_name())
91
+
92
+
93
+ def get_device_for(task):
94
+ if task in config.cmd_opts.use_cpu or "all" in config.cmd_opts.use_cpu:
95
+ return cpu
96
+
97
+ return get_optimal_device()
98
+
99
+
100
+ def torch_gc():
101
+ try:
102
+ if torch.cuda.is_available():
103
+ with torch.cuda.device(get_cuda_device_string()):
104
+ torch.cuda.empty_cache()
105
+ torch.cuda.ipc_collect()
106
+
107
+ if has_mps():
108
+ mac_devices.torch_mps_gc()
109
+ except Exception as e:
110
+ logger.error(f"Error in torch_gc", exc_info=True)
111
+
112
+
113
+ cpu: torch.device = torch.device("cpu")
114
+ device: torch.device = get_optimal_device()
115
+ dtype: torch.dtype = torch.float32
116
+ dtype_dvae: torch.dtype = torch.float32
117
+ dtype_vocos: torch.dtype = torch.float32
118
+ dtype_gpt: torch.dtype = torch.float32
119
+ dtype_decoder: torch.dtype = torch.float32
120
+
121
+
122
+ def reset_device():
123
+ if config.runtime_env_vars.half:
124
+ global dtype
125
+ global dtype_dvae
126
+ global dtype_vocos
127
+ global dtype_gpt
128
+ global dtype_decoder
129
+ dtype = torch.float16
130
+ dtype_dvae = torch.float16
131
+ dtype_vocos = torch.float16
132
+ dtype_gpt = torch.float16
133
+ dtype_decoder = torch.float16
134
+
135
+ logger.info("Using half precision: torch.float16")
136
+
137
+ if (
138
+ config.runtime_env_vars.device_id is not None
139
+ or config.runtime_env_vars.use_cpu is not None
140
+ ):
141
+ global device
142
+ device = get_optimal_device()
143
+
144
+ logger.info(f"Using device: {device}")
145
+
146
+
147
+ @lru_cache
148
+ def first_time_calculation():
149
+ """
150
+ just do any calculation with pytorch layers - the first time this is done it allocaltes about 700MB of memory and
151
+ spends about 2.7 seconds doing that, at least wih NVidia.
152
+ """
153
+
154
+ x = torch.zeros((1, 1)).to(device, dtype)
155
+ linear = torch.nn.Linear(1, 1).to(device, dtype)
156
+ linear(x)
157
+
158
+ x = torch.zeros((1, 1, 3, 3)).to(device, dtype)
159
+ conv2d = torch.nn.Conv2d(1, 1, (3, 3)).to(device, dtype)
160
+ conv2d(x)
modules/devices/mac_devices.py ADDED
@@ -0,0 +1,42 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import logging
3
+ from packaging import version
4
+ import torch.backends
5
+ import torch.backends.mps
6
+
7
+ logger = logging.getLogger(__name__)
8
+
9
+
10
+ def check_for_mps() -> bool:
11
+ if version.parse(torch.__version__) <= version.parse("2.0.1"):
12
+ if not getattr(torch, "has_mps", False):
13
+ return False
14
+ try:
15
+ torch.zeros(1).to(torch.device("mps"))
16
+ return True
17
+ except Exception:
18
+ return False
19
+ else:
20
+ try:
21
+ return torch.backends.mps.is_available() and torch.backends.mps.is_built()
22
+ except:
23
+ logger.warning("MPS garbage collection failed", exc_info=True)
24
+ return False
25
+
26
+
27
+ has_mps = check_for_mps()
28
+
29
+
30
+ def torch_mps_gc() -> None:
31
+ try:
32
+ from torch.mps import empty_cache
33
+
34
+ empty_cache()
35
+ except Exception:
36
+ logger.warning("MPS garbage collection failed", exc_info=True)
37
+
38
+
39
+ if __name__ == "__main__":
40
+ print(torch.__version__)
41
+ print(has_mps)
42
+ torch_mps_gc()
modules/generate_audio.py CHANGED
@@ -8,18 +8,20 @@ from modules import models, config
8
 
9
  import logging
10
 
11
- from modules import devices
 
 
 
12
 
13
  logger = logging.getLogger(__name__)
14
 
15
 
16
- @torch.inference_mode()
17
  def generate_audio(
18
  text: str,
19
  temperature: float = 0.3,
20
  top_P: float = 0.7,
21
  top_K: float = 20,
22
- spk: int | Speaker = -1,
23
  infer_seed: int = -1,
24
  use_decoder: bool = True,
25
  prompt1: str = "",
@@ -48,7 +50,7 @@ def generate_audio_batch(
48
  temperature: float = 0.3,
49
  top_P: float = 0.7,
50
  top_K: float = 20,
51
- spk: int | Speaker = -1,
52
  infer_seed: int = -1,
53
  use_decoder: bool = True,
54
  prompt1: str = "",
@@ -65,7 +67,7 @@ def generate_audio_batch(
65
  "prompt2": prompt2 or "",
66
  "prefix": prefix or "",
67
  "repetition_penalty": 1.0,
68
- "disable_tqdm": config.disable_tqdm,
69
  }
70
 
71
  if isinstance(spk, int):
@@ -103,6 +105,32 @@ def generate_audio_batch(
103
  return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
104
 
105
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
106
  if __name__ == "__main__":
107
  import soundfile as sf
108
 
 
8
 
9
  import logging
10
 
11
+ from modules.devices import devices
12
+ from typing import Union
13
+
14
+ from modules.utils.cache import conditional_cache
15
 
16
  logger = logging.getLogger(__name__)
17
 
18
 
 
19
  def generate_audio(
20
  text: str,
21
  temperature: float = 0.3,
22
  top_P: float = 0.7,
23
  top_K: float = 20,
24
+ spk: Union[int, Speaker] = -1,
25
  infer_seed: int = -1,
26
  use_decoder: bool = True,
27
  prompt1: str = "",
 
50
  temperature: float = 0.3,
51
  top_P: float = 0.7,
52
  top_K: float = 20,
53
+ spk: Union[int, Speaker] = -1,
54
  infer_seed: int = -1,
55
  use_decoder: bool = True,
56
  prompt1: str = "",
 
67
  "prompt2": prompt2 or "",
68
  "prefix": prefix or "",
69
  "repetition_penalty": 1.0,
70
+ "disable_tqdm": config.runtime_env_vars.off_tqdm,
71
  }
72
 
73
  if isinstance(spk, int):
 
105
  return [(sample_rate, np.array(wav).flatten().astype(np.float32)) for wav in wavs]
106
 
107
 
108
+ lru_cache_enabled = False
109
+
110
+
111
+ def setup_lru_cache():
112
+ global generate_audio_batch
113
+ global lru_cache_enabled
114
+
115
+ if lru_cache_enabled:
116
+ return
117
+ lru_cache_enabled = True
118
+
119
+ def should_cache(*args, **kwargs):
120
+ spk_seed = kwargs.get("spk", -1)
121
+ infer_seed = kwargs.get("infer_seed", -1)
122
+ return spk_seed != -1 and infer_seed != -1
123
+
124
+ lru_size = config.runtime_env_vars.lru_size
125
+ if isinstance(lru_size, int):
126
+ generate_audio_batch = conditional_cache(lru_size, should_cache)(
127
+ generate_audio_batch
128
+ )
129
+ logger.info(f"LRU cache enabled with size {lru_size}")
130
+ else:
131
+ logger.debug(f"LRU cache failed to enable, invalid size {lru_size}")
132
+
133
+
134
  if __name__ == "__main__":
135
  import soundfile as sf
136
 
modules/models.py CHANGED
@@ -1,15 +1,11 @@
1
- from modules.ChatTTS import ChatTTS
2
  import torch
3
-
4
  from modules import config
 
5
 
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
9
-
10
- device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
11
- print(f"device use {device}")
12
-
13
  chat_tts = None
14
 
15
 
@@ -17,25 +13,33 @@ def load_chat_tts():
17
  global chat_tts
18
  if chat_tts:
19
  return chat_tts
 
20
  chat_tts = ChatTTS.Chat()
21
  chat_tts.load_models(
22
- compile=config.enable_model_compile,
23
  source="local",
24
  local_path="./models/ChatTTS",
25
- device=device,
 
 
 
 
 
26
  )
27
 
28
- if config.model_config.get("half", False):
29
- logging.info("half precision enabled")
30
- for model_name, model in chat_tts.pretrain_models.items():
31
- if isinstance(model, torch.nn.Module):
32
- model.cpu()
33
- if torch.cuda.is_available():
34
- torch.cuda.empty_cache()
35
- model.half()
36
- if torch.cuda.is_available():
37
- model.cuda()
38
- model.eval()
39
- logger.log(logging.INFO, f"{model_name} converted to half precision.")
40
 
41
  return chat_tts
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import torch
2
+ from modules.ChatTTS import ChatTTS
3
  from modules import config
4
+ from modules.devices import devices
5
 
6
  import logging
7
 
8
  logger = logging.getLogger(__name__)
 
 
 
 
9
  chat_tts = None
10
 
11
 
 
13
  global chat_tts
14
  if chat_tts:
15
  return chat_tts
16
+
17
  chat_tts = ChatTTS.Chat()
18
  chat_tts.load_models(
19
+ compile=config.runtime_env_vars.compile,
20
  source="local",
21
  local_path="./models/ChatTTS",
22
+ device=devices.device,
23
+ dtype=devices.dtype,
24
+ dtype_vocos=devices.dtype_vocos,
25
+ dtype_dvae=devices.dtype_dvae,
26
+ dtype_gpt=devices.dtype_gpt,
27
+ dtype_decoder=devices.dtype_decoder,
28
  )
29
 
30
+ devices.torch_gc()
 
 
 
 
 
 
 
 
 
 
 
31
 
32
  return chat_tts
33
+
34
+
35
+ def reload_chat_tts():
36
+ logging.info("Reloading ChatTTS models")
37
+ global chat_tts
38
+ if chat_tts:
39
+ if torch.cuda.is_available():
40
+ for model_name, model in chat_tts.pretrain_models.items():
41
+ if isinstance(model, torch.nn.Module):
42
+ model.cpu()
43
+ torch.cuda.empty_cache()
44
+ chat_tts = None
45
+ return load_chat_tts()
modules/normalization.py CHANGED
@@ -1,6 +1,15 @@
1
  from modules.utils.zh_normalization.text_normlization import *
2
  import emojiswitch
3
  from modules.utils.markdown import markdown_to_text
 
 
 
 
 
 
 
 
 
4
 
5
  post_normalize_pipeline = []
6
  pre_normalize_pipeline = []
@@ -87,12 +96,17 @@ character_map = {
87
  ">": ",",
88
  "<": ",",
89
  "-": ",",
 
 
 
90
  }
91
 
92
  character_to_word = {
93
  " & ": " and ",
94
  }
95
 
 
 
96
 
97
  @post_normalize()
98
  def apply_character_to_word(text):
@@ -109,7 +123,8 @@ def apply_character_map(text):
109
 
110
  @post_normalize()
111
  def apply_emoji_map(text):
112
- return emojiswitch.demojize(text, delimiters=("", ""), lang="zh")
 
113
 
114
 
115
  @post_normalize()
@@ -122,6 +137,26 @@ def insert_spaces_between_uppercase(s):
122
  )
123
 
124
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
125
  @pre_normalize()
126
  def apply_markdown_to_text(text):
127
  if is_markdown(text):
@@ -186,7 +221,7 @@ def sentence_normalize(sentence_text: str):
186
  pattern = re.compile(r"(\[.+?\])|([^[]+)")
187
 
188
  def normalize_part(part):
189
- sentences = tx.normalize(part)
190
  dest_text = ""
191
  for sentence in sentences:
192
  sentence = apply_post_normalize(sentence)
@@ -244,6 +279,16 @@ console.log('1')
244
  “我们是玫瑰花。”花儿们说道。
245
  “啊!”小王子说……。
246
  """,
 
 
 
 
 
 
 
 
 
 
247
  ]
248
 
249
  for i, test_case in enumerate(test_cases):
 
1
  from modules.utils.zh_normalization.text_normlization import *
2
  import emojiswitch
3
  from modules.utils.markdown import markdown_to_text
4
+ from modules import models
5
+ import re
6
+
7
+
8
+ def is_chinese(text):
9
+ # 中文字符的 Unicode 范围是 \u4e00-\u9fff
10
+ chinese_pattern = re.compile(r"[\u4e00-\u9fff]")
11
+ return bool(chinese_pattern.search(text))
12
+
13
 
14
  post_normalize_pipeline = []
15
  pre_normalize_pipeline = []
 
96
  ">": ",",
97
  "<": ",",
98
  "-": ",",
99
+ "~": " ",
100
+ "~": " ",
101
+ "/": " ",
102
  }
103
 
104
  character_to_word = {
105
  " & ": " and ",
106
  }
107
 
108
+ ## ---------- post normalize ----------
109
+
110
 
111
  @post_normalize()
112
  def apply_character_to_word(text):
 
123
 
124
  @post_normalize()
125
  def apply_emoji_map(text):
126
+ lang = "zh" if is_chinese(text) else "en"
127
+ return emojiswitch.demojize(text, delimiters=("", ""), lang=lang)
128
 
129
 
130
  @post_normalize()
 
137
  )
138
 
139
 
140
+ @post_normalize()
141
+ def replace_unk_tokens(text):
142
+ """
143
+ 把不在字典里的字符替换为 " , "
144
+ """
145
+ chat_tts = models.load_chat_tts()
146
+ tokenizer = chat_tts.pretrain_models["tokenizer"]
147
+ vocab = tokenizer.get_vocab()
148
+ vocab_set = set(vocab.keys())
149
+ # 添加所有英语字符
150
+ vocab_set.update(set("abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ"))
151
+ vocab_set.update(set(" \n\r\t"))
152
+ replaced_chars = [char if char in vocab_set else " , " for char in text]
153
+ output_text = "".join(replaced_chars)
154
+ return output_text
155
+
156
+
157
+ ## ---------- pre normalize ----------
158
+
159
+
160
  @pre_normalize()
161
  def apply_markdown_to_text(text):
162
  if is_markdown(text):
 
221
  pattern = re.compile(r"(\[.+?\])|([^[]+)")
222
 
223
  def normalize_part(part):
224
+ sentences = tx.normalize(part) if is_chinese(part) else [part]
225
  dest_text = ""
226
  for sentence in sentences:
227
  sentence = apply_post_normalize(sentence)
 
279
  “我们是玫瑰花。”花儿们说道。
280
  “啊!”小王子说……。
281
  """,
282
+ """
283
+ State-of-the-art Machine Learning for PyTorch, TensorFlow, and JAX.
284
+
285
+ 🤗 Transformers provides APIs and tools to easily download and train state-of-the-art pretrained models. Using pretrained models can reduce your compute costs, carbon footprint, and save you the time and resources required to train a model from scratch. These models support common tasks in different modalities, such as:
286
+
287
+ 📝 Natural Language Processing: text classification, named entity recognition, question answering, language modeling, summarization, translation, multiple choice, and text generation.
288
+ 🖼️ Computer Vision: image classification, object detection, and segmentation.
289
+ 🗣️ Audio: automatic speech recognition and audio classification.
290
+ 🐙 Multimodal: table question answering, optical character recognition, information extraction from scanned documents, video classification, and visual question answering.
291
+ """,
292
  ]
293
 
294
  for i, test_case in enumerate(test_cases):
modules/refiner.py CHANGED
@@ -29,7 +29,7 @@ def refine_text(
29
  "temperature": temperature,
30
  "repetition_penalty": repetition_penalty,
31
  "max_new_token": max_new_token,
32
- "disable_tqdm": config.disable_tqdm,
33
  },
34
  do_text_normalization=False,
35
  )
 
29
  "temperature": temperature,
30
  "repetition_penalty": repetition_penalty,
31
  "max_new_token": max_new_token,
32
+ "disable_tqdm": config.runtime_env_vars.off_tqdm,
33
  },
34
  do_text_normalization=False,
35
  )
modules/speaker.py CHANGED
@@ -1,4 +1,5 @@
1
  import os
 
2
  import torch
3
 
4
  from modules import models
@@ -53,6 +54,14 @@ class Speaker:
53
 
54
  return is_update
55
 
 
 
 
 
 
 
 
 
56
 
57
  # 每个speaker就是一个 emb 文件 .pt
58
  # 管理 speaker 就是管理 ./data/speaker/ 下的所有 speaker
@@ -105,13 +114,13 @@ class SpeakerManager:
105
  self.refresh_speakers()
106
  return speaker
107
 
108
- def get_speaker(self, name) -> Speaker | None:
109
  for speaker in self.speakers.values():
110
  if speaker.name == name:
111
  return speaker
112
  return None
113
 
114
- def get_speaker_by_id(self, id) -> Speaker | None:
115
  for speaker in self.speakers.values():
116
  if str(speaker.id) == str(id):
117
  return speaker
 
1
  import os
2
+ from typing import Union
3
  import torch
4
 
5
  from modules import models
 
54
 
55
  return is_update
56
 
57
+ def __hash__(self):
58
+ return hash(str(self.id))
59
+
60
+ def __eq__(self, other):
61
+ if not isinstance(other, Speaker):
62
+ return False
63
+ return str(self.id) == str(other.id)
64
+
65
 
66
  # 每个speaker就是一个 emb 文件 .pt
67
  # 管理 speaker 就是管理 ./data/speaker/ 下的所有 speaker
 
114
  self.refresh_speakers()
115
  return speaker
116
 
117
+ def get_speaker(self, name) -> Union[Speaker, None]:
118
  for speaker in self.speakers.values():
119
  if speaker.name == name:
120
  return speaker
121
  return None
122
 
123
+ def get_speaker_by_id(self, id) -> Union[Speaker, None]:
124
  for speaker in self.speakers.values():
125
  if str(speaker.id) == str(id):
126
  return speaker
modules/synthesize_audio.py CHANGED
@@ -1,4 +1,5 @@
1
  import io
 
2
  from modules.SentenceSplitter import SentenceSplitter
3
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
4
 
@@ -14,7 +15,7 @@ def synthesize_audio(
14
  temperature: float = 0.3,
15
  top_P: float = 0.7,
16
  top_K: float = 20,
17
- spk: int | Speaker = -1,
18
  infer_seed: int = -1,
19
  use_decoder: bool = True,
20
  prompt1: str = "",
 
1
  import io
2
+ from typing import Union
3
  from modules.SentenceSplitter import SentenceSplitter
4
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
5
 
 
15
  temperature: float = 0.3,
16
  top_P: float = 0.7,
17
  top_K: float = 20,
18
+ spk: Union[int, Speaker] = -1,
19
  infer_seed: int = -1,
20
  use_decoder: bool = True,
21
  prompt1: str = "",
modules/utils/JsonObject.py ADDED
@@ -0,0 +1,113 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ class JsonObject:
2
+ def __init__(self, initial_dict=None):
3
+ """
4
+ Initialize the JsonObject with an optional initial dictionary.
5
+
6
+ :param initial_dict: A dictionary to initialize the JsonObject.
7
+ """
8
+ # If no initial dictionary is provided, use an empty dictionary
9
+ self._dict_obj = initial_dict if initial_dict is not None else {}
10
+
11
+ def __getattr__(self, name):
12
+ """
13
+ Get an attribute value. If the attribute does not exist,
14
+ look it up in the internal dictionary.
15
+
16
+ :param name: The name of the attribute.
17
+ :return: The value of the attribute.
18
+ :raises AttributeError: If the attribute is not found in the dictionary.
19
+ """
20
+ try:
21
+ return self._dict_obj[name]
22
+ except KeyError:
23
+ return None
24
+
25
+ def __setattr__(self, name, value):
26
+ """
27
+ Set an attribute value. If the attribute name is '_dict_obj',
28
+ set it directly as an instance attribute. Otherwise,
29
+ store it in the internal dictionary.
30
+
31
+ :param name: The name of the attribute.
32
+ :param value: The value to set.
33
+ """
34
+ if name == "_dict_obj":
35
+ super().__setattr__(name, value)
36
+ else:
37
+ self._dict_obj[name] = value
38
+
39
+ def __delattr__(self, name):
40
+ """
41
+ Delete an attribute. If the attribute does not exist,
42
+ look it up in the internal dictionary and remove it.
43
+
44
+ :param name: The name of the attribute.
45
+ :raises AttributeError: If the attribute is not found in the dictionary.
46
+ """
47
+ try:
48
+ del self._dict_obj[name]
49
+ except KeyError:
50
+ return
51
+
52
+ def __getitem__(self, key):
53
+ """
54
+ Get an item value from the internal dictionary.
55
+
56
+ :param key: The key of the item.
57
+ :return: The value of the item.
58
+ :raises KeyError: If the key is not found in the dictionary.
59
+ """
60
+ if key not in self._dict_obj:
61
+ return None
62
+ return self._dict_obj[key]
63
+
64
+ def __setitem__(self, key, value):
65
+ """
66
+ Set an item value in the internal dictionary.
67
+
68
+ :param key: The key of the item.
69
+ :param value: The value to set.
70
+ """
71
+ self._dict_obj[key] = value
72
+
73
+ def __delitem__(self, key):
74
+ """
75
+ Delete an item from the internal dictionary.
76
+
77
+ :param key: The key of the item.
78
+ :raises KeyError: If the key is not found in the dictionary.
79
+ """
80
+ del self._dict_obj[key]
81
+
82
+ def to_dict(self):
83
+ """
84
+ Convert the JsonObject back to a regular dictionary.
85
+
86
+ :return: The internal dictionary.
87
+ """
88
+ return self._dict_obj
89
+
90
+ def has_key(self, key):
91
+ """
92
+ Check if the key exists in the internal dictionary.
93
+
94
+ :param key: The key to check.
95
+ :return: True if the key exists, False otherwise.
96
+ """
97
+ return key in self._dict_obj
98
+
99
+ def keys(self):
100
+ """
101
+ Get a list of keys in the internal dictionary.
102
+
103
+ :return: A list of keys.
104
+ """
105
+ return self._dict_obj.keys()
106
+
107
+ def values(self):
108
+ """
109
+ Get a list of values in the internal dictionary.
110
+
111
+ :return: A list of values.
112
+ """
113
+ return self._dict_obj.values()
modules/utils/cache.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, TypeVar, Any
2
+ from typing_extensions import ParamSpec
3
+
4
+ from functools import lru_cache, _CacheInfo
5
+
6
+
7
+ def conditional_cache(maxsize: int, condition: Callable):
8
+ def decorator(func):
9
+ @lru_cache_ext(maxsize=maxsize)
10
+ def cached_func(*args, **kwargs):
11
+ return func(*args, **kwargs)
12
+
13
+ def wrapper(*args, **kwargs):
14
+ if condition(*args, **kwargs):
15
+ return cached_func(*args, **kwargs)
16
+ else:
17
+ return func(*args, **kwargs)
18
+
19
+ return wrapper
20
+
21
+ return decorator
22
+
23
+
24
+ def hash_list(l: list) -> int:
25
+ __hash = 0
26
+ for i, e in enumerate(l):
27
+ __hash = hash((__hash, i, hash_item(e)))
28
+ return __hash
29
+
30
+
31
+ def hash_dict(d: dict) -> int:
32
+ __hash = 0
33
+ for k, v in d.items():
34
+ __hash = hash((__hash, k, hash_item(v)))
35
+ return __hash
36
+
37
+
38
+ def hash_item(e) -> int:
39
+ if hasattr(e, "__hash__") and callable(e.__hash__):
40
+ try:
41
+ return hash(e)
42
+ except TypeError:
43
+ pass
44
+ if isinstance(e, (list, set, tuple)):
45
+ return hash_list(list(e))
46
+ elif isinstance(e, (dict)):
47
+ return hash_dict(e)
48
+ else:
49
+ raise TypeError(f"unhashable type: {e.__class__}")
50
+
51
+
52
+ PT = ParamSpec("PT")
53
+ RT = TypeVar("RT")
54
+
55
+
56
+ def lru_cache_ext(
57
+ *opts, hashfunc: Callable[..., int] = hash_item, **kwopts
58
+ ) -> Callable[[Callable[PT, RT]], Callable[PT, RT]]:
59
+ def decorator(func: Callable[PT, RT]) -> Callable[PT, RT]:
60
+ class _lru_cache_ext_wrapper:
61
+ args: tuple
62
+ kwargs: dict[str, Any]
63
+
64
+ def cache_info(self) -> _CacheInfo: ...
65
+ def cache_clear(self) -> None: ...
66
+
67
+ @classmethod
68
+ @lru_cache(*opts, **kwopts)
69
+ def cached_func(cls, args_hash: int) -> RT:
70
+ return func(*cls.args, **cls.kwargs)
71
+
72
+ @classmethod
73
+ def __call__(cls, *args: PT.args, **kwargs: PT.kwargs) -> RT:
74
+ __hash = hashfunc(
75
+ (
76
+ id(func),
77
+ *[hashfunc(a) for a in args],
78
+ *[(hashfunc(k), hashfunc(v)) for k, v in kwargs.items()],
79
+ )
80
+ )
81
+
82
+ cls.args = args
83
+ cls.kwargs = kwargs
84
+
85
+ cls.cache_info = cls.cached_func.cache_info
86
+ cls.cache_clear = cls.cached_func.cache_clear
87
+
88
+ return cls.cached_func(__hash)
89
+
90
+ return _lru_cache_ext_wrapper()
91
+
92
+ return decorator
modules/utils/zh_normalization/text_normlization.py CHANGED
@@ -72,9 +72,9 @@ class TextNormalizer():
72
  return sentences
73
 
74
  def _post_replace(self, sentence: str) -> str:
75
- sentence = sentence.replace('/', '每')
76
- sentence = sentence.replace('~', '至')
77
- sentence = sentence.replace('~', '至')
78
  sentence = sentence.replace('①', '一')
79
  sentence = sentence.replace('②', '二')
80
  sentence = sentence.replace('③', '三')
 
72
  return sentences
73
 
74
  def _post_replace(self, sentence: str) -> str:
75
+ # sentence = sentence.replace('/', '每')
76
+ # sentence = sentence.replace('~', '至')
77
+ # sentence = sentence.replace('~', '至')
78
  sentence = sentence.replace('①', '一')
79
  sentence = sentence.replace('②', '二')
80
  sentence = sentence.replace('③', '三')
webui.py CHANGED
@@ -14,9 +14,11 @@ except:
14
  import os
15
  import logging
16
 
17
- from numpy import clip
18
 
 
19
  from modules.synthesize_audio import synthesize_audio
 
20
 
21
  logging.basicConfig(
22
  level=os.getenv("LOG_LEVEL", "INFO"),
@@ -25,20 +27,17 @@ logging.basicConfig(
25
 
26
 
27
  import gradio as gr
28
- import io
29
- import re
30
- import numpy as np
31
 
32
  import torch
33
 
34
  from modules.ssml import parse_ssml
35
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
36
- from modules.generate_audio import generate_audio, generate_audio_batch
37
 
38
  from modules.speaker import speaker_mgr
39
  from modules.data import styles_mgr
40
 
41
  from modules.api.utils import calc_spk_style
 
42
 
43
  from modules.normalization import text_normalize
44
  from modules import refiner, config
@@ -147,7 +146,7 @@ def tts_generate(
147
  prompt1 = prompt1 or params.get("prompt1", "")
148
  prompt2 = prompt2 or params.get("prompt2", "")
149
 
150
- infer_seed = clip(infer_seed, -1, 2**32 - 1)
151
  infer_seed = int(infer_seed)
152
 
153
  if not disable_normalize:
@@ -869,31 +868,59 @@ if __name__ == "__main__":
869
  type=int,
870
  help="Max batch size for TTS",
871
  )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
872
 
873
  args = parser.parse_args()
874
 
875
- server_name = env.get_env_or_arg(args, "server_name", "0.0.0.0", str)
876
- server_port = env.get_env_or_arg(args, "server_port", 7860, int)
877
- share = env.get_env_or_arg(args, "share", False, bool)
878
- debug = env.get_env_or_arg(args, "debug", False, bool)
879
- auth = env.get_env_or_arg(args, "auth", None, str)
880
- half = env.get_env_or_arg(args, "half", False, bool)
881
- off_tqdm = env.get_env_or_arg(args, "off_tqdm", False, bool)
882
-
883
- webui_config["tts_max"] = env.get_env_or_arg(args, "tts_max_len", 1000, int)
884
- webui_config["ssml_max"] = env.get_env_or_arg(args, "ssml_max_len", 5000, int)
885
- webui_config["max_batch_size"] = env.get_env_or_arg(args, "max_batch_size", 8, int)
 
 
 
 
 
 
 
 
 
 
886
 
887
  demo = create_interface()
888
 
889
  if auth:
890
  auth = tuple(auth.split(":"))
891
 
892
- if half:
893
- config.model_config["half"] = True
894
-
895
- if off_tqdm:
896
- config.disable_tqdm = True
897
 
898
  demo.queue().launch(
899
  server_name=server_name,
 
14
  import os
15
  import logging
16
 
17
+ import numpy as np
18
 
19
+ from modules.devices import devices
20
  from modules.synthesize_audio import synthesize_audio
21
+ from modules.utils.cache import conditional_cache
22
 
23
  logging.basicConfig(
24
  level=os.getenv("LOG_LEVEL", "INFO"),
 
27
 
28
 
29
  import gradio as gr
 
 
 
30
 
31
  import torch
32
 
33
  from modules.ssml import parse_ssml
34
  from modules.SynthesizeSegments import SynthesizeSegments, combine_audio_segments
 
35
 
36
  from modules.speaker import speaker_mgr
37
  from modules.data import styles_mgr
38
 
39
  from modules.api.utils import calc_spk_style
40
+ import modules.generate_audio as generate
41
 
42
  from modules.normalization import text_normalize
43
  from modules import refiner, config
 
146
  prompt1 = prompt1 or params.get("prompt1", "")
147
  prompt2 = prompt2 or params.get("prompt2", "")
148
 
149
+ infer_seed = np.clip(infer_seed, -1, 2**32 - 1)
150
  infer_seed = int(infer_seed)
151
 
152
  if not disable_normalize:
 
868
  type=int,
869
  help="Max batch size for TTS",
870
  )
871
+ parser.add_argument(
872
+ "--lru_size",
873
+ type=int,
874
+ default=64,
875
+ help="Set the size of the request cache pool, set it to 0 will disable lru_cache",
876
+ )
877
+ parser.add_argument(
878
+ "--device_id",
879
+ type=str,
880
+ help="Select the default CUDA device to use (export CUDA_VISIBLE_DEVICES=0,1,etc might be needed before)",
881
+ default=None,
882
+ )
883
+ parser.add_argument(
884
+ "--use_cpu",
885
+ nargs="+",
886
+ help="use CPU as torch device for specified modules",
887
+ default=[],
888
+ type=str.lower,
889
+ )
890
+ parser.add_argument("--compile", action="store_true", help="Enable model compile")
891
 
892
  args = parser.parse_args()
893
 
894
+ def get_and_update_env(*args):
895
+ val = env.get_env_or_arg(*args)
896
+ key = args[1]
897
+ config.runtime_env_vars[key] = val
898
+ return val
899
+
900
+ server_name = get_and_update_env(args, "server_name", "0.0.0.0", str)
901
+ server_port = get_and_update_env(args, "server_port", 7860, int)
902
+ share = get_and_update_env(args, "share", False, bool)
903
+ debug = get_and_update_env(args, "debug", False, bool)
904
+ auth = get_and_update_env(args, "auth", None, str)
905
+ half = get_and_update_env(args, "half", False, bool)
906
+ off_tqdm = get_and_update_env(args, "off_tqdm", False, bool)
907
+ lru_size = get_and_update_env(args, "lru_size", 64, int)
908
+ device_id = get_and_update_env(args, "device_id", None, str)
909
+ use_cpu = get_and_update_env(args, "use_cpu", [], list)
910
+ compile = get_and_update_env(args, "compile", False, bool)
911
+
912
+ webui_config["tts_max"] = get_and_update_env(args, "tts_max_len", 1000, int)
913
+ webui_config["ssml_max"] = get_and_update_env(args, "ssml_max_len", 5000, int)
914
+ webui_config["max_batch_size"] = get_and_update_env(args, "max_batch_size", 8, int)
915
 
916
  demo = create_interface()
917
 
918
  if auth:
919
  auth = tuple(auth.split(":"))
920
 
921
+ generate.setup_lru_cache()
922
+ devices.reset_device()
923
+ devices.first_time_calculation()
 
 
924
 
925
  demo.queue().launch(
926
  server_name=server_name,