MZhao-LEGION committed
Commit 84fef35
1 Parent(s): f2c4c94

multilingual model!

Files changed (7)
  1. Data/TalkFlower_CNzh/config.json +0 -96
  2. app.py +1 -1
  3. config.yml +13 -13
  4. emo_gen.py +16 -23
  5. infer.py +18 -21
  6. presets.py +114 -55
  7. utils.py +70 -0
Data/TalkFlower_CNzh/config.json DELETED
@@ -1,96 +0,0 @@
-{
-  "train": {
-    "log_interval": 200,
-    "eval_interval": 1000,
-    "seed": 42,
-    "epochs": 1000,
-    "learning_rate": 0.0002,
-    "betas": [
-      0.8,
-      0.99
-    ],
-    "eps": 1e-09,
-    "batch_size": 12,
-    "fp16_run": false,
-    "lr_decay": 0.99995,
-    "segment_size": 16384,
-    "init_lr_ratio": 1,
-    "warmup_epochs": 0,
-    "c_mel": 45,
-    "c_kl": 1.0,
-    "skip_optimizer": true
-  },
-  "data": {
-    "training_files": "filelists/train.list",
-    "validation_files": "filelists/val.list",
-    "max_wav_value": 32768.0,
-    "sampling_rate": 44100,
-    "filter_length": 2048,
-    "hop_length": 512,
-    "win_length": 2048,
-    "n_mel_channels": 128,
-    "mel_fmin": 0.0,
-    "mel_fmax": null,
-    "add_blank": true,
-    "n_speakers": 700,
-    "cleaned_text": true,
-    "spk2id": {
-      "TalkFlower_CNzh": 0
-    }
-  },
-  "model": {
-    "use_spk_conditioned_encoder": true,
-    "use_noise_scaled_mas": true,
-    "use_mel_posterior_encoder": false,
-    "use_duration_discriminator": true,
-    "inter_channels": 192,
-    "hidden_channels": 192,
-    "filter_channels": 768,
-    "n_heads": 2,
-    "n_layers": 6,
-    "kernel_size": 3,
-    "p_dropout": 0.1,
-    "resblock": "1",
-    "resblock_kernel_sizes": [
-      3,
-      7,
-      11
-    ],
-    "resblock_dilation_sizes": [
-      [
-        1,
-        3,
-        5
-      ],
-      [
-        1,
-        3,
-        5
-      ],
-      [
-        1,
-        3,
-        5
-      ]
-    ],
-    "upsample_rates": [
-      8,
-      8,
-      2,
-      2,
-      2
-    ],
-    "upsample_initial_channel": 512,
-    "upsample_kernel_sizes": [
-      16,
-      16,
-      8,
-      2,
-      2
-    ],
-    "n_layers_q": 3,
-    "use_spectral_norm": false,
-    "gin_channels": 256
-  },
-  "version": "2.0"
-}
app.py CHANGED
@@ -7,7 +7,7 @@ from presets import *
 with gr.Blocks(css=customCSS) as demo:
     exceed_flag = gr.State(value=False)
     tmp_string = gr.Textbox(value="", visible=False)
-    character_area = gr.HTML(get_character_html("你好呀!"), elem_id="character_area")
+    character_area = gr.HTML(get_character_html("你好呀!我现在支持多语言了呢!"), elem_id="character_area")
     with gr.Tab("Speak", elem_id="tab-speak"):
         speak_input = gr.Textbox(lines=1, label="Talking Flower will say:", elem_classes="wonder-card input_text", elem_id="speak_input")
         speak_button = gr.Button("Speak!", elem_id="speak_button", elem_classes="main-button wonder-card")
config.yml CHANGED
@@ -4,7 +4,7 @@
 # 拟提供通用路径配置,统一存放数据,避免数据放得很乱
 # 每个数据集与其对应的模型存放至统一路径下,后续所有的路径配置均为相对于datasetPath的路径
 # 不填或者填空则路径为相对于项目根目录的路径
-dataset_path: "Data/TalkFlower_CNzh"
+dataset_path: ""

 # 模型镜像源,默认huggingface,使用openi镜像源需指定openi_token
 mirror: ""
@@ -34,7 +34,7 @@ preprocess_text:
   # 验证集路径
   val_path: "filelists/val.list"
   # 配置文件路径
-  config_path: "Data/TalkFlower_CNzh/config.json"
+  config_path: "Data/config.json"
   # 每个speaker的验证集条数
   val_per_spk: 5
   # 验证集最大条数,多于的会被截断并放到训练集中
@@ -47,12 +47,12 @@ preprocess_text:
 # 注意, “:” 后需要加空格
 bert_gen:
   # 训练数据集配置文件路径
-  config_path: "Data/TalkFlower_CNzh/config.json"
+  config_path: "Data/config.json"
   # 并行数
   num_processes: 8
   # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
   # 该选项同时决定了get_bert_feature的默认设备
-  device: "cuda"
+  device: "cpu"
   # 使用多卡推理
   use_multi_device: false

@@ -60,11 +60,11 @@ bert_gen:
 # 注意, “:” 后需要加空格
 emo_gen:
   # 训练数据集配置文件路径
-  config_path: "Data/TalkFlower_CNzh/config.json"
+  config_path: "Data/config.json"
   # 并行数
   num_processes: 2
   # 使用设备:可选项 "cuda" 显卡推理,"cpu" cpu推理
-  device: "cuda"
+  device: "cpu"

 # train 训练配置
 # 注意, “:” 后需要加空格
@@ -85,7 +85,7 @@ train_ms:
   # 训练模型存储目录:与旧版本的区别,原先数据集是存放在logs/model_name下的,现在改为统一存放在Data/你的数据集/models下
   model: "models"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "Data/config.json"
   # 训练使用的worker,不建议超过CPU核心数
   num_workers: 16
   # 关闭此项可以节约接近50%的磁盘空间,但是可能导致实际训练速度变慢和更高的CPU使用率。
@@ -100,9 +100,9 @@ webui:
   # 推理设备
   device: "cpu"
   # 模型路径
-  model: "../../models/G_48000.pth"
+  model: "models/G_multilingual.pth"
   # 配置文件路径
-  config_path: "config.json"
+  config_path: "Data/config.json"
   # 端口号
   port: 7860
   # 是否公开部署,对外网开放
@@ -120,16 +120,16 @@ server:
   # 端口号
   port: 5000
   # 模型默认使用设备:但是当前并没有实现这个配置。
-  device: "cuda"
+  device: "cpu"
   # 需要加载的所有模型的配置
   # 注意,所有模型都必须正确配置model与config的路径,空路径会导致加载错误。
   models:
     - # 模型的路径
-      model: "models/G_48000.pth"
+      model: "models/G_multilingual.pth"
      # 模型config.json的路径
-      config: "TalkFlower_CNzh/config.json"
+      config: "Data/config.json"
      # 模型使用设备,若填写则会覆盖默认配置
-      device: "cuda"
+      device: "cpu"
      # 模型默认使用的语言
      language: "ZH"
      # 模型人物默认参数
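Note: per the comment at the top of config.yml, every relative path below it is resolved against dataset_path, so clearing dataset_path and pointing the tools at "Data/config.json" keeps one shared config at the project root. A minimal sketch of that resolution rule; the os.path.join below is only an assumption about how config.py combines the two values, shown for illustration:

    import os

    dataset_path = ""                  # new value in this commit (was "Data/TalkFlower_CNzh")
    config_path = "Data/config.json"   # new shared config location

    # With an empty dataset_path, the relative path is effectively project-root relative.
    print(os.path.join(dataset_path, config_path))  # -> Data/config.json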
emo_gen.py CHANGED
@@ -1,19 +1,21 @@
+import argparse
+import os
+from pathlib import Path
+
+import librosa
+import numpy as np
 import torch
 import torch.nn as nn
-from torch.utils.data import Dataset
-from torch.utils.data import DataLoader
+from torch.utils.data import DataLoader, Dataset
+from tqdm import tqdm
 from transformers import Wav2Vec2Processor
 from transformers.models.wav2vec2.modeling_wav2vec2 import (
     Wav2Vec2Model,
     Wav2Vec2PreTrainedModel,
 )
-import librosa
-import numpy as np
-import argparse
-from config import config
+
 import utils
-import os
-from tqdm import tqdm
+from config import config


 class RegressionHead(nn.Module):
@@ -78,11 +80,6 @@ class AudioDataset(Dataset):
         return torch.from_numpy(processed_data)


-model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
-processor = Wav2Vec2Processor.from_pretrained(model_name)
-model = EmotionModel.from_pretrained(model_name)
-
-
 def process_func(
     x: np.ndarray,
     sampling_rate: int,
@@ -135,16 +132,12 @@ if __name__ == "__main__":
     device = config.bert_gen_config.device

     model_name = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
-    processor = (
-        Wav2Vec2Processor.from_pretrained(model_name)
-        if processor is None
-        else processor
-    )
-    model = (
-        EmotionModel.from_pretrained(model_name).to(device)
-        if model is None
-        else model.to(device)
-    )
+    REPO_ID = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"
+    if not Path(model_name).joinpath("pytorch_model.bin").exists():
+        utils.download_emo_models(config.mirror, model_name, REPO_ID)
+
+    processor = Wav2Vec2Processor.from_pretrained(model_name)
+    model = EmotionModel.from_pretrained(model_name).to(device)

     lines = []
     with open(hps.data.training_files, encoding="utf-8") as f:
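Note: the reworked __main__ block above fetches the emotion model only when pytorch_model.bin is missing locally. A standalone sketch of that guard using the plain huggingface_hub path (the openi mirror branch handled by utils.download_emo_models is omitted here):

    from pathlib import Path

    from huggingface_hub import hf_hub_download

    # Paths as used in emo_gen.py above.
    model_dir = "./emotional/wav2vec2-large-robust-12-ft-emotion-msp-dim"
    repo_id = "audeering/wav2vec2-large-robust-12-ft-emotion-msp-dim"

    # Download the weights once; later runs find them on disk and skip the call.
    if not Path(model_dir, "pytorch_model.bin").exists():
        hf_hub_download(
            repo_id,
            "pytorch_model.bin",
            local_dir=model_dir,
            local_dir_use_symlinks=False,
        )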
infer.py CHANGED
@@ -29,7 +29,7 @@ from oldVersion.V101.text import symbols as V101symbols
 from oldVersion import V111, V110, V101, V200

 # 当前版本信息
-latest_version = "2.0"
+latest_version = "2.1"

 # 版本兼容
 SynthesizerTrnMap = {
@@ -82,7 +82,7 @@ def get_net_g(model_path: str, version: str, device: str, hps):
     return net_g


-def get_text(text, reference_audio, emotion, language_str, hps, device):
+def get_text(text, language_str, hps, device):
     # 在此处实现当前版本的get_text
     norm_text, phone, tone, word2ph = clean_text(text, language_str)
     phone, tone, language = cleaned_text_to_sequence(phone, tone, language_str)
@@ -113,12 +113,6 @@ def get_text(text, reference_audio, emotion, language_str, hps, device):
     else:
         raise ValueError("language_str should be ZH, JP or EN")

-    emo = (
-        torch.from_numpy(get_emo(reference_audio))
-        if reference_audio
-        else torch.Tensor([emotion])
-    )
-
     assert bert.shape[-1] == len(
         phone
     ), f"Bert seq len {bert.shape[-1]} != {len(phone)}"
@@ -126,7 +120,16 @@ def get_text(text, reference_audio, emotion, language_str, hps, device):
     phone = torch.LongTensor(phone)
     tone = torch.LongTensor(tone)
     language = torch.LongTensor(language)
-    return bert, ja_bert, en_bert, emo, phone, tone, language
+    return bert, ja_bert, en_bert, phone, tone, language
+
+
+def get_emo_(reference_audio, emotion):
+    emo = (
+        torch.from_numpy(get_emo(reference_audio))
+        if reference_audio
+        else torch.Tensor([emotion])
+    )
+    return emo


 def infer(
@@ -191,9 +194,10 @@ def infer(
             device,
         )
     # 在此处实现当前版本的推理
-    bert, ja_bert, en_bert, emo, phones, tones, lang_ids = get_text(
-        text, reference_audio, emotion, language, hps, device
+    bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
+        text, language, hps, device
     )
+    emo = get_emo_(reference_audio, emotion)
     if skip_start:
         phones = phones[1:]
         tones = tones[1:]
@@ -261,10 +265,8 @@ def infer_multilang(
     skip_start=False,
     skip_end=False,
 ):
-    bert, ja_bert, en_bert, emo, phones, tones, lang_ids = [], [], [], [], [], [], []
-    # bert, ja_bert, en_bert, phones, tones, lang_ids = get_text(
-    #     text, language, hps, device
-    # )
+    bert, ja_bert, en_bert, phones, tones, lang_ids = [], [], [], [], [], []
+    emo = get_emo_(reference_audio, emotion)
     for idx, (txt, lang) in enumerate(zip(text, language)):
         skip_start = (idx != 0) or (skip_start and idx == 0)
         skip_end = (idx != len(text) - 1) or (skip_end and idx == len(text) - 1)
@@ -272,16 +274,14 @@ def infer_multilang(
             temp_bert,
             temp_ja_bert,
             temp_en_bert,
-            temp_emo,
             temp_phones,
             temp_tones,
             temp_lang_ids,
-        ) = get_text(txt, ref, emotion, language, hps, device)
+        ) = get_text(txt, lang, hps, device)
         if skip_start:
             temp_bert = temp_bert[:, 1:]
             temp_ja_bert = temp_ja_bert[:, 1:]
             temp_en_bert = temp_en_bert[:, 1:]
-            temp_emo = temp_emo[:, 1:]
             temp_phones = temp_phones[1:]
             temp_tones = temp_tones[1:]
             temp_lang_ids = temp_lang_ids[1:]
@@ -289,21 +289,18 @@ def infer_multilang(
             temp_bert = temp_bert[:, :-1]
             temp_ja_bert = temp_ja_bert[:, :-1]
             temp_en_bert = temp_en_bert[:, :-1]
-            temp_emo = temp_emo[:, :-1]
             temp_phones = temp_phones[:-1]
             temp_tones = temp_tones[:-1]
             temp_lang_ids = temp_lang_ids[:-1]
         bert.append(temp_bert)
         ja_bert.append(temp_ja_bert)
         en_bert.append(temp_en_bert)
-        emo.append(temp_emo)
         phones.append(temp_phones)
         tones.append(temp_tones)
         lang_ids.append(temp_lang_ids)
     bert = torch.concatenate(bert, dim=1)
     ja_bert = torch.concatenate(ja_bert, dim=1)
     en_bert = torch.concatenate(en_bert, dim=1)
-    emo = torch.concatenate(emo, dim=1)
     phones = torch.concatenate(phones, dim=0)
     tones = torch.concatenate(tones, dim=0)
     lang_ids = torch.concatenate(lang_ids, dim=0)
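Note: emotion handling is now decoupled from the text front end: get_text() returns only the BERT features and phoneme tensors, while get_emo_() supplies the emotion input, falling back to a scalar emotion id when no reference audio is given. A self-contained illustration of that fallback pattern; get_emo from the repo is replaced here by a dummy extractor:

    import numpy as np
    import torch


    def dummy_get_emo(audio_path):
        # Stand-in for infer.get_emo, which extracts a wav2vec2 emotion embedding.
        return np.zeros(1024, dtype=np.float32)


    def select_emotion(reference_audio, emotion):
        # Same pattern as get_emo_ in the diff: prefer the reference-audio embedding,
        # otherwise wrap the numeric emotion id in a tensor.
        return (
            torch.from_numpy(dummy_get_emo(reference_audio))
            if reference_audio
            else torch.Tensor([emotion])
        )


    print(select_emotion(None, 3))             # tensor([3.])
    print(select_emotion("ref.wav", 3).shape)  # torch.Size([1024])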
presets.py CHANGED
@@ -4,10 +4,11 @@ import numpy as np
 import torch
 import re_matching
 import utils
-from infer import infer, latest_version, get_net_g
+from infer import infer, latest_version, get_net_g, infer_multilang
 import gradio as gr
 from config import config
 from tools.webui import reload_javascript, get_character_html
+from tools.sentence import split_by_language

 logging.basicConfig(
     level=logging.INFO,
@@ -42,6 +43,7 @@ def speak_fn(
     interval_between_para=0.2,  # 段间间隔
     interval_between_sent=1,  # 句间间隔
 ):
+    audio_list = []
     while text.find("\n\n") != -1:
         text = text.replace("\n\n", "\n")
     if len(text) > 100:
@@ -54,58 +56,113 @@ def speak_fn(
         audio_value = "./assets/audios/overlength.wav"
         exceed_flag = not exceed_flag
     else:
-        audio_list = []
-        if len(text) > 42:
-            logging.info(f"Long Text: {text}")
-            para_list = re_matching.cut_para(text)
-            for p in para_list:
-                audio_list_sent = []
-                sent_list = re_matching.cut_sent(p)
-                for s in sent_list:
-                    audio = infer(
-                        s,
-                        sdp_ratio=sdp_ratio,
-                        noise_scale=noise_scale,
-                        noise_scale_w=noise_scale_w,
-                        length_scale=length_scale,
-                        sid=speaker,
-                        language=language,
-                        hps=hps,
-                        net_g=net_g,
-                        device=device,
-                        reference_audio=reference_audio,
-                        emotion=emotion,
-                    )
-                    audio_list_sent.append(audio)
-                    silence = np.zeros((int)(44100 * interval_between_sent))
-                    audio_list_sent.append(silence)
-                if (interval_between_para - interval_between_sent) > 0:
-                    silence = np.zeros((int)(44100 * (interval_between_para - interval_between_sent)))
-                    audio_list_sent.append(silence)
-                audio16bit = gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_list_sent))  # 对完整句子做音量归一
-                audio_list.append(audio16bit)
-        else:
-            logging.info(f"Short Text: {text}")
-            silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
-            with torch.no_grad():
-                for piece in text.split("|"):
-                    audio = infer(
-                        piece,
-                        sdp_ratio=sdp_ratio,
-                        noise_scale=noise_scale,
-                        noise_scale_w=noise_scale_w,
-                        length_scale=length_scale,
-                        sid=speaker,
-                        language=language,
-                        hps=hps,
-                        net_g=net_g,
-                        device=device,
-                        reference_audio=reference_audio,
-                        emotion=emotion,
-                    )
-                    audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
-                    audio_list.append(audio16bit)
-                    audio_list.append(silence)  # 将静音添加到列表中
+        for idx, slice in enumerate(text.split("|")):
+            if slice == "":
+                continue
+            skip_start = idx != 0
+            skip_end = idx != len(text.split("|")) - 1
+            sentences_list = split_by_language(
+                slice, target_languages=["zh", "ja", "en"]
+            )
+            idx = 0
+            while idx < len(sentences_list):
+                text_to_generate = []
+                lang_to_generate = []
+                while True:
+                    content, lang = sentences_list[idx]
+                    temp_text = [content]
+                    lang = lang.upper()
+                    if lang == "JA":
+                        lang = "JP"
+                    if len(text_to_generate) > 0:
+                        text_to_generate[-1] += [temp_text.pop(0)]
+                        lang_to_generate[-1] += [lang]
+                    if len(temp_text) > 0:
+                        text_to_generate += [[i] for i in temp_text]
+                        lang_to_generate += [[lang]] * len(temp_text)
+                    if idx + 1 < len(sentences_list):
+                        idx += 1
+                    else:
+                        break
+                skip_start = (idx != 0) and skip_start
+                skip_end = (idx != len(sentences_list) - 1) and skip_end
+                print(text_to_generate, lang_to_generate)
+
+                with torch.no_grad():
+                    for i, piece in enumerate(text_to_generate):
+                        skip_start = (i != 0) and skip_start
+                        skip_end = (i != len(text_to_generate) - 1) and skip_end
+                        audio = infer_multilang(
+                            piece,
+                            reference_audio=reference_audio,
+                            emotion=emotion,
+                            sdp_ratio=sdp_ratio,
+                            noise_scale=noise_scale,
+                            noise_scale_w=noise_scale_w,
+                            length_scale=length_scale,
+                            sid=speaker,
+                            language=lang_to_generate[i],
+                            hps=hps,
+                            net_g=net_g,
+                            device=device,
+                            skip_start=skip_start,
+                            skip_end=skip_end,
+                        )
+                        audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+                        audio_list.append(audio16bit)
+                idx += 1
+        # 单一语言推理
+        # if len(text) > 42:
+        #     logging.info(f"Long Text: {text}")
+        #     para_list = re_matching.cut_para(text)
+        #     for p in para_list:
+        #         audio_list_sent = []
+        #         sent_list = re_matching.cut_sent(p)
+        #         for s in sent_list:
+        #             audio = infer(
+        #                 s,
+        #                 sdp_ratio=sdp_ratio,
+        #                 noise_scale=noise_scale,
+        #                 noise_scale_w=noise_scale_w,
+        #                 length_scale=length_scale,
+        #                 sid=speaker,
+        #                 language=language,
+        #                 hps=hps,
+        #                 net_g=net_g,
+        #                 device=device,
+        #                 reference_audio=reference_audio,
+        #                 emotion=emotion,
+        #             )
+        #             audio_list_sent.append(audio)
+        #             silence = np.zeros((int)(44100 * interval_between_sent))
+        #             audio_list_sent.append(silence)
+        #         if (interval_between_para - interval_between_sent) > 0:
+        #             silence = np.zeros((int)(44100 * (interval_between_para - interval_between_sent)))
+        #             audio_list_sent.append(silence)
+        #         audio16bit = gr.processing_utils.convert_to_16_bit_wav(np.concatenate(audio_list_sent))  # 对完整句子做音量归一
+        #         audio_list.append(audio16bit)
+        # else:
+        #     logging.info(f"Short Text: {text}")
+        #     silence = np.zeros(hps.data.sampling_rate // 2, dtype=np.int16)
+        #     with torch.no_grad():
+        #         for piece in text.split("|"):
+        #             audio = infer(
+        #                 piece,
+        #                 sdp_ratio=sdp_ratio,
+        #                 noise_scale=noise_scale,
+        #                 noise_scale_w=noise_scale_w,
+        #                 length_scale=length_scale,
+        #                 sid=speaker,
+        #                 language=language,
+        #                 hps=hps,
+        #                 net_g=net_g,
+        #                 device=device,
+        #                 reference_audio=reference_audio,
+        #                 emotion=emotion,
+        #             )
+        #             audio16bit = gr.processing_utils.convert_to_16_bit_wav(audio)
+        #             audio_list.append(audio16bit)
+        #             audio_list.append(silence)  # 将静音添加到列表中

     audio_concat = np.concatenate(audio_list)
     audio_value = (hps.data.sampling_rate, audio_concat)
@@ -113,13 +170,15 @@ def speak_fn(
     return gr.update(value=audio_value, autoplay=True), get_character_html(text), exceed_flag, gr.update(interactive=True)


+
 def submit_lock_fn():
     return gr.update(interactive=False)


 def init_fn():
-    gr.Info("2023-11-24: 优化长句生成效果;增加示例;更新了一些小彩蛋;画了一些大饼)")
-    gr.Info("Only support Chinese now. Trying to train a mutilingual model. 欢迎在 Community 中提建议~")
+    gr.Info("2023-11-28: 支持多语言啦!闲聊花花现在能说中、英、日语啦!")
+    # gr.Info("2023-11-24: 优化长句生成效果;增加示例;更新了一些小彩蛋;画了一些大饼)")
+    gr.Info("Support languages: ZH|EN|JA. 欢迎在 Community 中提建议~")

     index = random.randint(1,7)
     welcome_text = get_sentence("Welcome", index)
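Note: speak_fn now routes every "|"-separated slice through tools.sentence.split_by_language and infer_multilang instead of the single-language infer path (kept above as commented-out code). A simplified sketch of just the language-tag normalization step; the real loop additionally groups segments into nested lists and tracks skip_start/skip_end, and the sentences_list value below is a hypothetical split result, not output from the repo:

    # Hypothetical output of split_by_language for one input slice.
    sentences_list = [("你好,", "zh"), ("hello, ", "en"), ("こんにちは", "ja")]

    text_to_generate, lang_to_generate = [], []
    for content, lang in sentences_list:
        lang = lang.upper()
        if lang == "JA":  # the synthesizer expects "JP" as its Japanese tag
            lang = "JP"
        text_to_generate.append(content)
        lang_to_generate.append(lang)

    print(text_to_generate)  # ['你好,', 'hello, ', 'こんにちは']
    print(lang_to_generate)  # ['ZH', 'EN', 'JP']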
utils.py CHANGED
@@ -9,12 +9,31 @@ import numpy as np
 from huggingface_hub import hf_hub_download
 from scipy.io.wavfile import read
 import torch
+import re

 MATPLOTLIB_FLAG = False

 logger = logging.getLogger(__name__)


+def download_emo_models(mirror, repo_id, model_name):
+    if mirror == "openi":
+        import openi
+
+        openi.model.download_model(
+            "Stardust_minus/Bert-VITS2",
+            repo_id.split("/")[-1],
+            "./emotional",
+        )
+    else:
+        hf_hub_download(
+            repo_id,
+            "pytorch_model.bin",
+            local_dir=model_name,
+            local_dir_use_symlinks=False,
+        )
+
+
 def download_checkpoint(
     dir_path, repo_config, token=None, regex="G_*.pth", mirror="openi"
 ):
@@ -385,3 +404,54 @@ class HParams:

     def __repr__(self):
         return self.__dict__.__repr__()
+
+
+def load_model(model_path, config_path):
+    hps = get_hparams_from_file(config_path)
+    net = SynthesizerTrn(
+        # len(symbols),
+        108,
+        hps.data.filter_length // 2 + 1,
+        hps.train.segment_size // hps.data.hop_length,
+        n_speakers=hps.data.n_speakers,
+        **hps.model,
+    ).to("cpu")
+    _ = net.eval()
+    _ = load_checkpoint(model_path, net, None, skip_optimizer=True)
+    return net
+
+
+def mix_model(
+    network1, network2, output_path, voice_ratio=(0.5, 0.5), tone_ratio=(0.5, 0.5)
+):
+    if hasattr(network1, "module"):
+        state_dict1 = network1.module.state_dict()
+        state_dict2 = network2.module.state_dict()
+    else:
+        state_dict1 = network1.state_dict()
+        state_dict2 = network2.state_dict()
+    for k in state_dict1.keys():
+        if k not in state_dict2.keys():
+            continue
+        if "enc_p" in k:
+            state_dict1[k] = (
+                state_dict1[k].clone() * tone_ratio[0]
+                + state_dict2[k].clone() * tone_ratio[1]
+            )
+        else:
+            state_dict1[k] = (
+                state_dict1[k].clone() * voice_ratio[0]
+                + state_dict2[k].clone() * voice_ratio[1]
+            )
+    for k in state_dict2.keys():
+        if k not in state_dict1.keys():
+            state_dict1[k] = state_dict2[k].clone()
+    torch.save(
+        {"model": state_dict1, "iteration": 0, "optimizer": None, "learning_rate": 0},
+        output_path,
+    )
+
+
+def get_steps(model_path):
+    matches = re.findall(r"\d+", model_path)
+    return matches[-1] if matches else None
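Note: the new mix_model helper blends two checkpoints key by key, weighting text-encoder ("enc_p") parameters by tone_ratio and everything else by voice_ratio. A toy illustration of that interpolation on dummy state dicts; the tensors and key names are made up for the example:

    import torch

    # Two toy "checkpoints"; enc_p.* stands for the text encoder, dec.* for the rest.
    sd1 = {"enc_p.weight": torch.ones(2), "dec.weight": torch.zeros(2)}
    sd2 = {"enc_p.weight": torch.zeros(2), "dec.weight": torch.ones(2)}

    voice_ratio, tone_ratio = (0.5, 0.5), (0.3, 0.7)

    mixed = {}
    for k in sd1:
        ratio = tone_ratio if "enc_p" in k else voice_ratio
        mixed[k] = sd1[k] * ratio[0] + sd2[k] * ratio[1]

    print(mixed["enc_p.weight"])  # tensor([0.3000, 0.3000])
    print(mixed["dec.weight"])    # tensor([0.5000, 0.5000])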