fix lint
- .github/workflows/lint.yml +1 -1
- cosyvoice/bin/export_jit.py +4 -1
- cosyvoice/bin/export_onnx.py +7 -4
- cosyvoice/bin/inference.py +4 -6
- cosyvoice/bin/train.py +1 -0
- cosyvoice/cli/cosyvoice.py +3 -2
- cosyvoice/cli/frontend.py +16 -11
- cosyvoice/cli/model.py +34 -33
- cosyvoice/dataset/dataset.py +1 -1
- cosyvoice/dataset/processor.py +2 -1
- cosyvoice/flow/decoder.py +1 -2
- cosyvoice/flow/flow.py +7 -2
- cosyvoice/flow/flow_matching.py +1 -0
- cosyvoice/flow/length_regulator.py +0 -0
- cosyvoice/hifigan/f0_predictor.py +0 -0
- cosyvoice/hifigan/generator.py +7 -4
- cosyvoice/llm/llm.py +10 -5
- cosyvoice/transformer/embedding.py +2 -2
- cosyvoice/utils/common.py +6 -1
- cosyvoice/utils/executor.py +2 -1
- cosyvoice/utils/file_utils.py +3 -0
- cosyvoice/utils/frontend_utils.py +1 -0
- cosyvoice/utils/scheduler.py +1 -2
- cosyvoice/utils/train_utils.py +2 -2
- examples/libritts/cosyvoice/local/prepare_data.py +2 -0
- examples/libritts/cosyvoice/run.sh +1 -1
- examples/magicdata-read/cosyvoice/local/prepare_data.py +2 -0
- examples/magicdata-read/cosyvoice/run.sh +1 -1
- runtime/python/fastapi/client.py +4 -2
- runtime/python/fastapi/server.py +11 -5
- runtime/python/grpc/client.py +2 -1
- runtime/python/grpc/server.py +11 -5
- tools/extract_embedding.py +1 -0
- tools/make_parquet_list.py +1 -0
- webui.py +37 -26
.github/workflows/lint.yml CHANGED
@@ -51,5 +51,5 @@ jobs:
 set -eux
 pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
 flake8 --version
-flake8 --max-line-length
+flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
 if [ $? != 0 ]; then exit 1; fi
cosyvoice/bin/export_jit.py CHANGED
@@ -19,12 +19,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
+import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-import torch
 from cosyvoice.cli.cosyvoice import CosyVoice
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='export your model for deployment')
     parser.add_argument('--model_dir',
@@ -35,6 +36,7 @@ def get_args():
     print(args)
     return args
 
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -67,5 +69,6 @@ def main():
     script = torch.jit.optimize_for_inference(script)
     script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))
 
+
 if __name__ == '__main__':
     main()
cosyvoice/bin/export_onnx.py CHANGED
@@ -20,13 +20,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../..'.format(ROOT_DIR))
-sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 import onnxruntime
 import random
 import torch
 from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
 
 
@@ -50,6 +50,7 @@ def get_args():
     print(args)
     return args
 
+
 def main():
     args = get_args()
     logging.basicConfig(level=logging.DEBUG,
@@ -89,7 +90,8 @@ def main():
     option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
     option.intra_op_num_threads = 1
     providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+    estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+                                                  sess_options=option, providers=providers)
 
     for _ in tqdm(range(10)):
         x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
@@ -105,5 +107,6 @@ def main():
     output_onnx = estimator_onnx.run(None, ort_inputs)[0]
     torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)
 
+
 if __name__ == "__main__":
     main()
cosyvoice/bin/inference.py CHANGED
@@ -18,16 +18,15 @@ import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
-
 import torch
 from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
 from cosyvoice.cli.model import CosyVoiceModel
-
 from cosyvoice.dataset.dataset import Dataset
 
+
 def get_args():
     parser = argparse.ArgumentParser(description='inference with your model')
     parser.add_argument('--config', required=True, help='config file')
@@ -66,7 +65,8 @@ def main():
     model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
     model.load(args.llm_model, args.flow_model, args.hifigan_model)
 
-    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+    test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+                           tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
     test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)
 
     del configs
@@ -74,13 +74,11 @@ def main():
     fn = os.path.join(args.result_dir, 'wav.scp')
     f = open(fn, 'w')
     with torch.no_grad():
-        for
+        for _, batch in tqdm(enumerate(test_data_loader)):
             utts = batch["utts"]
             assert len(utts) == 1, "inference mode only support batchsize 1"
-            text = batch["text"]
             text_token = batch["text_token"].to(device)
             text_token_len = batch["text_token_len"].to(device)
-            tts_text = batch["tts_text"]
             tts_index = batch["tts_index"]
             tts_text_token = batch["tts_text_token"].to(device)
             tts_text_token_len = batch["tts_text_token_len"].to(device)
cosyvoice/bin/train.py CHANGED
@@ -132,5 +132,6 @@ def main():
     executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
     dist.destroy_process_group(group_join)
 
+
 if __name__ == '__main__':
     main()
cosyvoice/cli/cosyvoice.py CHANGED
@@ -20,6 +20,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
 from cosyvoice.utils.file_utils import logging
 
+
 class CosyVoice:
 
     def __init__(self, model_dir, load_jit=True, load_onnx=True):
@@ -42,8 +43,8 @@ class CosyVoice:
                                 '{}/hift.pt'.format(model_dir))
         if load_jit:
             self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-
-
+                                '{}/llm.llm.fp16.zip'.format(model_dir),
+                                '{}/flow.encoder.fp32.zip'.format(model_dir))
         if load_onnx:
             self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
         del configs
cosyvoice/cli/frontend.py CHANGED
@@ -50,7 +50,9 @@ class CosyVoiceFrontEnd:
         option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
         option.intra_op_num_threads = 1
         self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
-        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+        self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+                                                                     providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+                                                                                "CPUExecutionProvider"])
         if os.path.exists(spk2info):
             self.spk2info = torch.load(spk2info, map_location=self.device)
         self.instruct = instruct
@@ -60,7 +62,8 @@ class CosyVoiceFrontEnd:
         if self.use_ttsfrd:
             self.frd = ttsfrd.TtsFrontendEngine()
             ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True,
+            assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+                'failed to initialize ttsfrd resource'
             self.frd.set_lang_type('pinyin')
             self.frd.enable_pinyin_mix(True)
             self.frd.set_breakmodel_index(1)
@@ -76,8 +79,11 @@ class CosyVoiceFrontEnd:
 
     def _extract_speech_token(self, speech):
         feat = whisper.log_mel_spectrogram(speech, n_mels=128)
-        speech_token = self.speech_tokenizer_session.run(None,
-
+        speech_token = self.speech_tokenizer_session.run(None,
+                                                         {self.speech_tokenizer_session.get_inputs()[0].name:
+                                                          feat.detach().cpu().numpy(),
+                                                          self.speech_tokenizer_session.get_inputs()[1].name:
+                                                          np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
         speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
         speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
         return speech_token, speech_token_len
@@ -88,7 +94,8 @@ class CosyVoiceFrontEnd:
                                   dither=0,
                                   sample_frequency=16000)
         feat = feat - feat.mean(dim=0, keepdim=True)
-        embedding = self.campplus_session.run(None,
+        embedding = self.campplus_session.run(None,
+                                              {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
         embedding = torch.tensor([embedding]).to(self.device)
         return embedding
 
@@ -112,18 +119,16 @@ class CosyVoiceFrontEnd:
             text = text.replace(" - ", ",")
             text = remove_bracket(text)
             text = re.sub(r'[,,]+$', '。', text)
-            texts =
-
-            comma_split=False)]
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+                                         token_min_n=60, merge_len=20, comma_split=False))
         else:
             if self.use_ttsfrd:
                 text = self.frd.get_frd_extra_info(text, 'input')
             else:
                 text = self.en_tn_model.normalize(text)
                 text = spell_out_number(text, self.inflect_parser)
-            texts =
-
-            comma_split=False)]
+            texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+                                         token_min_n=60, merge_len=20, comma_split=False))
         if split is False:
             return text
         return texts
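Both of the reflowed onnxruntime calls above (the speech tokenizer and the campplus embedding) follow the same pattern: build a dict keyed by the session's declared input names and pass it to run(). A minimal sketch of that pattern, with a placeholder model path and made-up shapes rather than anything from this commit:

import numpy as np
import onnxruntime


def run_named(session, *arrays):
    # Pair each array with the session's declared input names, mirroring the
    # get_inputs()[i].name dictionaries used in _extract_speech_token above.
    feeds = {inp.name: arr for inp, arr in zip(session.get_inputs(), arrays)}
    return session.run(None, feeds)


# Hypothetical usage -- 'speech_tokenizer.onnx' and the mel shape are placeholders:
# session = onnxruntime.InferenceSession('speech_tokenizer.onnx', providers=['CPUExecutionProvider'])
# mel = np.zeros((1, 128, 200), dtype=np.float32)
# tokens = run_named(session, mel, np.array([mel.shape[2]], dtype=np.int32))[0].flatten().tolist()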
cosyvoice/cli/model.py CHANGED
@@ -18,7 +18,7 @@ import time
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-
+
 
 class CosyVoiceModel:
 
@@ -80,27 +80,27 @@ class CosyVoiceModel:
     def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
         with self.llm_context:
             for i in self.llm.inference(text=text.to(self.device),
-
-
-
-
-
-
-
-
-
+                                        text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+                                        embedding=llm_embedding.to(self.device).half(),
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3):
                 self.tts_speech_token_dict[uuid].append(i)
         self.llm_end_dict[uuid] = True
 
     def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
         with self.flow_hift_context:
             tts_mel = self.flow.inference(token=token.to(self.device),
-
-
-
-
-
-
+                                          token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_token=prompt_token.to(self.device),
+                                          prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+                                          prompt_feat=prompt_feat.to(self.device),
+                                          prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+                                          embedding=embedding.to(self.device))
             # mel overlap fade in out
             if self.mel_overlap_dict[uuid] is not None:
                 tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
@@ -129,7 +129,8 @@ class CosyVoiceModel:
         # this_uuid is used to track variables related to this inference thread
        this_uuid = str(uuid.uuid1())
        with self.lock:
-            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid]
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
        p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
        p.start()
        if stream is True:
@@ -140,12 +141,12 @@ class CosyVoiceModel:
                     this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
                     with self.flow_hift_context:
                         this_tts_speech = self.token2wav(token=this_tts_speech_token,
-
-
-
-
-
-                        yield
+                                                         prompt_token=flow_prompt_speech_token,
+                                                         prompt_feat=prompt_speech_feat,
+                                                         embedding=flow_embedding,
+                                                         uuid=this_uuid,
+                                                         finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
                     with self.lock:
                         self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
                     # increase token_hop_len for better speech quality
@@ -157,11 +158,11 @@ class CosyVoiceModel:
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-
-
-
-
-
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         else:
             # deal with all tokens
@@ -169,11 +170,11 @@ class CosyVoiceModel:
             this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
             with self.flow_hift_context:
                 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-
-
-
-
-
+                                                 prompt_token=flow_prompt_speech_token,
+                                                 prompt_feat=prompt_speech_feat,
+                                                 embedding=flow_embedding,
+                                                 uuid=this_uuid,
+                                                 finalize=True)
             yield {'tts_speech': this_tts_speech.cpu()}
         with self.lock:
             self.tts_speech_token_dict.pop(this_uuid)
cosyvoice/dataset/dataset.py CHANGED
@@ -148,7 +148,7 @@ def Dataset(data_list_file,
             tts_data = json.load(f)
         utt2lists = read_json_lists(prompt_utt2data)
         # filter unnecessary file in inference mode
-        lists = list(
+        lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
     dataset = DataList(lists,
                        shuffle=shuffle,
                        partition=partition)
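A toy run of the rebuilt filter line above (the file names are made up): only parquet lists that some prompt utterance actually references survive, deduplicated.

utt2lists = {'utt1': 'a.parquet', 'utt2': 'a.parquet', 'utt3': 'b.parquet'}
tts_data = {'utt1': ['hello'], 'utt2': ['world']}
lists = ['a.parquet', 'b.parquet']
lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
print(lists)  # ['a.parquet']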
cosyvoice/dataset/processor.py CHANGED
@@ -23,7 +23,7 @@ import torch.nn.functional as F
 
 torchaudio.set_audio_backend('soundfile')
 
-AUDIO_FORMAT_SETS =
+AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}
 
 
 def parquet_opener(data, mode='train', tts_data={}):
@@ -54,6 +54,7 @@ def parquet_opener(data, mode='train', tts_data={}):
     except Exception as ex:
         logging.warning('Failed to open {}, ex info {}'.format(url, ex))
 
+
 def filter(data,
            max_length=10240,
           min_length=10,
cosyvoice/flow/decoder.py CHANGED
@@ -74,7 +74,7 @@ class ConditionalDecoder(nn.Module):
             )
             self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))
 
-        for
+        for _ in range(num_mid_blocks):
             input_channel = channels[-1]
             out_channels = channels[-1]
             resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
@@ -126,7 +126,6 @@ class ConditionalDecoder(nn.Module):
         self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
         self.initialize_weights()
 
-
     def initialize_weights(self):
         for m in self.modules():
             if isinstance(m, nn.Conv1d):
cosyvoice/flow/flow.py CHANGED
@@ -33,8 +33,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
                  encoder: torch.nn.Module = None,
                  length_regulator: torch.nn.Module = None,
                  decoder: torch.nn.Module = None,
-                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
-
+                 decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+                                       'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+                                                                 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+                                       'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+                                                          'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+                 mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+                                        'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
         super().__init__()
         self.input_size = input_size
         self.output_size = output_size
cosyvoice/flow/flow_matching.py CHANGED
@@ -15,6 +15,7 @@ import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM
 
+
 class ConditionalCFM(BASECFM):
     def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
         super().__init__(
cosyvoice/flow/length_regulator.py CHANGED
File without changes
cosyvoice/hifigan/f0_predictor.py CHANGED
File without changes
cosyvoice/hifigan/generator.py CHANGED
@@ -38,6 +38,8 @@ This code is modified from https://github.com/jik876/hifi-gan
     https://github.com/NVIDIA/BigVGAN
 
 """
+
+
 class ResBlock(torch.nn.Module):
     """Residual block module in HiFiGAN/BigVGAN."""
     def __init__(
@@ -100,6 +102,7 @@ class ResBlock(torch.nn.Module):
             remove_weight_norm(self.convs1[idx])
             remove_weight_norm(self.convs2[idx])
 
+
 class SineGen(torch.nn.Module):
     """ Definition of sine generator
     SineGen(samp_rate, harmonic_num = 0,
@@ -286,8 +289,7 @@ class HiFTGenerator(nn.Module):
         self.source_resblocks = nn.ModuleList()
         downsample_rates = [1] + upsample_rates[::-1][:-1]
         downsample_cum_rates = np.cumprod(downsample_rates)
-        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
-                                          source_resblock_dilation_sizes)):
+        for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
             if u == 1:
                 self.source_downs.append(
                     Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
@@ -304,7 +306,7 @@ class HiFTGenerator(nn.Module):
         self.resblocks = nn.ModuleList()
         for i in range(len(self.ups)):
             ch = base_channels // (2**(i + 1))
-            for
+            for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
                 self.resblocks.append(ResBlock(ch, k, d))
 
         self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
@@ -332,7 +334,8 @@ class HiFTGenerator(nn.Module):
         magnitude = torch.clip(magnitude, max=1e2)
         real = magnitude * torch.cos(phase)
         img = magnitude * torch.sin(phase)
-        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+        inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+                                        self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
         return inverse_transform
 
     def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
cosyvoice/llm/llm.py CHANGED
@@ -80,7 +80,8 @@ class TransformerLM(torch.nn.Module):
     def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
         text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
         speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+        lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+                    for i in range(len(text_token))]
         lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
         lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
         return lm_input, lm_input_len
@@ -104,7 +105,8 @@ class TransformerLM(torch.nn.Module):
         embedding = batch['embedding'].to(device)
 
         # 1. prepare llm_target
-        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+        lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+                                  [self.speech_token_size]) for i in range(text_token.size(0))]
         lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)
 
         # 1. encode text_token
@@ -124,7 +126,8 @@ class TransformerLM(torch.nn.Module):
         speech_token = self.speech_embedding(speech_token)
 
         # 5. unpad and pad
-        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+        lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+                                                         task_id_emb, speech_token, speech_token_len)
 
         # 6. run lm forward
         lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
@@ -194,8 +197,10 @@ class TransformerLM(torch.nn.Module):
         offset = 0
         att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
         for i in range(max_len):
-            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
-
+            y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
+                                                                  att_cache=att_cache, cnn_cache=cnn_cache,
+                                                                  att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
+                                                                                                 device=lm_input.device)).to(torch.bool))
             logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
             top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
cosyvoice/transformer/embedding.py CHANGED
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
 
     """
 
-    def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+    def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
         """Construct an PositionalEncoding object."""
         super(EspnetRelPositionalEncoding, self).__init__()
         self.d_model = d_model
@@ -289,6 +289,6 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
         """
         pos_emb = self.pe[
             :,
-            self.pe.size(1) // 2 - size + 1
+            self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
         ]
         return pos_emb
cosyvoice/utils/common.py CHANGED
@@ -102,6 +102,7 @@ def init_weights(m, mean=0.0, std=0.01):
     if classname.find("Conv") != -1:
         m.weight.data.normal_(mean, std)
 
+
 # Repetition Aware Sampling in VALL-E 2
 def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
     top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
@@ -110,6 +111,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
         top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
     return top_ids
 
+
 def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
     prob, indices = [], []
     cum_prob = 0.0
@@ -127,13 +129,16 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
     top_ids = indices[prob.multinomial(1, replacement=True)]
     return top_ids
 
+
 def random_sampling(weighted_scores, decoded_tokens, sampling):
     top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
     return top_ids
 
+
 def fade_in_out(fade_in_mel, fade_out_mel, window):
     device = fade_in_mel.device
     fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
     mel_overlap_len = int(window.shape[0] / 2)
-    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] +
+    fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
+                                          fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
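The completed fade_in_out assignment above is a plain overlap-add: the first half of the window scales the head of the incoming mel chunk, the second half scales the tail of the previous one. A self-contained check with illustrative shapes (the real window and overlap length come from the model config, not from this commit):

import torch

window = torch.hamming_window(2 * 34)   # illustrative overlap window
fade_in_mel = torch.randn(1, 80, 100)   # current chunk: (batch, mel bins, frames)
fade_out_mel = torch.randn(1, 80, 100)  # previous chunk
mel_overlap_len = int(window.shape[0] / 2)
fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
                                      fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
print(fade_in_mel.shape)  # torch.Size([1, 80, 100])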
cosyvoice/utils/executor.py CHANGED
@@ -70,7 +70,8 @@ class Executor:
                     info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
                     log_per_step(writer, info_dict)
                     # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-                    if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and
+                    if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+                            (batch_idx + 1) % info_dict["accum_grad"] == 0:
                         dist.barrier()
                         self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
                         model.train()
cosyvoice/utils/file_utils.py CHANGED
@@ -28,6 +28,7 @@ def read_lists(list_file):
             lists.append(line.strip())
     return lists
 
+
 def read_json_lists(list_file):
     lists = read_lists(list_file)
     results = {}
@@ -36,6 +37,7 @@ def read_json_lists(list_file):
             results.update(json.load(fin))
     return results
 
+
 def load_wav(wav, target_sr):
     speech, sample_rate = torchaudio.load(wav)
     speech = speech.mean(dim=0, keepdim=True)
@@ -44,6 +46,7 @@ def load_wav(wav, target_sr):
         speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
     return speech
 
+
 def speed_change(waveform, sample_rate, speed_factor: str):
     effects = [
         ["tempo", speed_factor],  # speed_factor
cosyvoice/utils/frontend_utils.py CHANGED
@@ -15,6 +15,7 @@
 import re
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')
 
+
 # whether contain chinese character
 def contains_chinese(text):
     return bool(chinese_char_pattern.search(text))
cosyvoice/utils/scheduler.py CHANGED
@@ -567,8 +567,7 @@ class NoamAnnealing(_LRScheduler):
                  min_lr=0.0,
                  last_epoch=-1):
         self._normalize = d_model**(-0.5)
-        assert not (warmup_steps is not None
-                    and warmup_ratio is not None), \
+        assert not (warmup_steps is not None and warmup_ratio is not None), \
             "Either use particular number of step or ratio"
         assert warmup_ratio is None or max_steps is not None, \
             "If there is a ratio, there should be a total steps"
cosyvoice/utils/train_utils.py CHANGED
@@ -69,7 +69,6 @@ def init_dataset_and_dataloader(args, configs):
     return train_dataset, cv_dataset, train_data_loader, cv_data_loader
 
 
-
 def check_modify_and_save_config(args, configs):
     if args.train_engine == "torch_ddp":
         configs['train_conf']["dtype"] = 'fp32'
@@ -84,7 +83,8 @@ def check_modify_and_save_config(args, configs):
         configs['train_conf']["dtype"] = "fp32"
         assert ds_configs["train_micro_batch_size_per_gpu"] == 1
         # if use deepspeed, override ddp config
-        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
+        configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
+                                                     configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
         configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
         configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
         configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
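The completed save_per_step override above only rescales the checkpoint interval when deepspeed's gradient accumulation replaces the ddp setting, so the number of micro-batches between saves stays roughly constant. With illustrative numbers:

save_per_step, accum_grad, ds_accum_steps = 1000, 2, 4   # made-up values
print(int(save_per_step * accum_grad / ds_accum_steps))  # 500 optimizer steps, same micro-batch cadence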
examples/libritts/cosyvoice/local/prepare_data.py CHANGED
@@ -7,6 +7,7 @@ from tqdm import tqdm
 
 logger = logging.getLogger()
 
+
 def main():
     wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))
 
@@ -41,6 +42,7 @@ def main():
             f.write('{} {}\n'.format(k, ' '.join(v)))
     return
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--src_dir',
examples/libritts/cosyvoice/run.sh CHANGED
@@ -83,7 +83,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
 cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
 cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
-for model in llm; do
+for model in llm flow; do
 torchrun --nnodes=1 --nproc_per_node=$num_gpus \
     --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
   cosyvoice/bin/train.py \
examples/magicdata-read/cosyvoice/local/prepare_data.py CHANGED
@@ -6,6 +6,7 @@ from tqdm import tqdm
 
 logger = logging.getLogger()
 
+
 def main():
     utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
     with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
@@ -40,6 +41,7 @@ def main():
             f.write('{} {}\n'.format(k, ' '.join(v)))
     return
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--src_dir',
examples/magicdata-read/cosyvoice/run.sh CHANGED
@@ -83,7 +83,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
 cp data/train/parquet/data.list data/train.data.list
 cp data/dev/parquet/data.list data/dev.data.list
-for model in llm; do
+for model in llm flow; do
 torchrun --nnodes=1 --nproc_per_node=$num_gpus \
     --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
   cosyvoice/bin/train.py \
runtime/python/fastapi/client.py CHANGED
@@ -38,7 +38,7 @@ def main():
         payload = {
             'tts_text': args.tts_text,
         }
-        files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
+        files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
         response = requests.request("GET", url, data=payload, files=files, stream=True)
     else:
         payload = {
@@ -55,6 +55,7 @@ def main():
     torchaudio.save(args.tts_wav, tts_speech, target_sr)
     logging.info('get response')
 
+
 if __name__ == "__main__":
     parser = argparse.ArgumentParser()
     parser.add_argument('--host',
@@ -81,7 +82,8 @@ if __name__ == "__main__":
                         default='../../../zero_shot_prompt.wav')
     parser.add_argument('--instruct_text',
                         type=str,
-                        default='Theo \'Crimson\', is a fiery, passionate rebel leader.
+                        default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
+                                 Fights with fervor for justice, but struggles with impulsiveness.')
     parser.add_argument('--tts_wav',
                         type=str,
                         default='demo.wav')
runtime/python/fastapi/server.py
CHANGED
@@ -13,9 +13,6 @@
|
|
13 |
# limitations under the License.
|
14 |
import os
|
15 |
import sys
|
16 |
-
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
17 |
-
sys.path.append('{}/../../..'.format(ROOT_DIR))
|
18 |
-
sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
|
19 |
import argparse
|
20 |
import logging
|
21 |
logging.getLogger('matplotlib').setLevel(logging.WARNING)
|
@@ -24,6 +21,9 @@ from fastapi.responses import StreamingResponse
|
|
24 |
from fastapi.middleware.cors import CORSMiddleware
|
25 |
import uvicorn
|
26 |
import numpy as np
|
|
|
|
|
|
|
27 |
from cosyvoice.cli.cosyvoice import CosyVoice
|
28 |
from cosyvoice.utils.file_utils import load_wav
|
29 |
|
@@ -36,34 +36,40 @@ app.add_middleware(
|
|
36 |
allow_methods=["*"],
|
37 |
allow_headers=["*"])
|
38 |
|
|
|
39 |
def generate_data(model_output):
|
40 |
for i in model_output:
|
41 |
tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
|
42 |
yield tts_audio
|
43 |
|
|
|
44 |
@app.get("/inference_sft")
|
45 |
async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
|
46 |
model_output = cosyvoice.inference_sft(tts_text, spk_id)
|
47 |
return StreamingResponse(generate_data(model_output))
|
48 |
|
|
|
49 |
@app.get("/inference_zero_shot")
|
50 |
async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
|
51 |
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
52 |
model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
|
53 |
return StreamingResponse(generate_data(model_output))
|
54 |
|
|
|
55 |
@app.get("/inference_cross_lingual")
|
56 |
async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
|
57 |
prompt_speech_16k = load_wav(prompt_wav.file, 16000)
|
58 |
model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
|
59 |
return StreamingResponse(generate_data(model_output))
|
60 |
|
|
|
61 |
@app.get("/inference_instruct")
|
62 |
async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
|
63 |
model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
|
64 |
return StreamingResponse(generate_data(model_output))
|
65 |
|
66 |
-
|
|
|
67 |
parser = argparse.ArgumentParser()
|
68 |
parser.add_argument('--port',
|
69 |
type=int,
|
@@ -74,4 +80,4 @@ if __name__=='__main__':
|
|
74 |
help='local path or modelscope repo id')
|
75 |
args = parser.parse_args()
|
76 |
cosyvoice = CosyVoice(args.model_dir)
|
77 |
-
uvicorn.run(app, host="127.0.0.1", port=args.port)
|
|
|
13 | # limitations under the License.
14 | import os
15 | import sys
16 | import argparse
17 | import logging
18 | logging.getLogger('matplotlib').setLevel(logging.WARNING)
21 | from fastapi.middleware.cors import CORSMiddleware
22 | import uvicorn
23 | import numpy as np
24 | + ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
25 | + sys.path.append('{}/../../..'.format(ROOT_DIR))
26 | + sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
27 | from cosyvoice.cli.cosyvoice import CosyVoice
28 | from cosyvoice.utils.file_utils import load_wav
29 |
36 |     allow_methods=["*"],
37 |     allow_headers=["*"])
38 |
39 | +
40 | def generate_data(model_output):
41 |     for i in model_output:
42 |         tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
43 |         yield tts_audio
44 |
45 | +
46 | @app.get("/inference_sft")
47 | async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
48 |     model_output = cosyvoice.inference_sft(tts_text, spk_id)
49 |     return StreamingResponse(generate_data(model_output))
50 |
51 | +
52 | @app.get("/inference_zero_shot")
53 | async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
54 |     prompt_speech_16k = load_wav(prompt_wav.file, 16000)
55 |     model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
56 |     return StreamingResponse(generate_data(model_output))
57 |
58 | +
59 | @app.get("/inference_cross_lingual")
60 | async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
61 |     prompt_speech_16k = load_wav(prompt_wav.file, 16000)
62 |     model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
63 |     return StreamingResponse(generate_data(model_output))
64 |
65 | +
66 | @app.get("/inference_instruct")
67 | async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
68 |     model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
69 |     return StreamingResponse(generate_data(model_output))
70 |
71 | +
72 | + if __name__ == '__main__':
73 |     parser = argparse.ArgumentParser()
74 |     parser.add_argument('--port',
75 |                         type=int,
80 |                         help='local path or modelscope repo id')
81 |     args = parser.parse_args()
82 |     cosyvoice = CosyVoice(args.model_dir)
83 | +     uvicorn.run(app, host="127.0.0.1", port=args.port)
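For reference, the endpoints above stream the raw 16-bit PCM produced by generate_data. The following client sketch is not part of this change; it assumes the server is running locally on port 50000 and that the checkpoint's output rate is 22050 Hz (both may differ in your setup):

import numpy as np
import requests
import torch
import torchaudio

# Minimal sketch of calling the streaming /inference_sft endpoint.
# Assumptions (not from this diff): 127.0.0.1:50000, a speaker id such as '中文女'
# existing in the SFT model, and 22050 Hz output audio.
url = 'http://127.0.0.1:50000/inference_sft'
payload = {'tts_text': '你好，欢迎使用CosyVoice。', 'spk_id': '中文女'}

response = requests.request('GET', url, data=payload, stream=True)
tts_audio = b''
for chunk in response.iter_content(chunk_size=16000):
    tts_audio += chunk

# Mirror generate_data: the stream is little-endian int16 scaled by 2 ** 15.
tts_speech = torch.from_numpy(np.array(np.frombuffer(tts_audio, dtype=np.int16))).unsqueeze(dim=0)
torchaudio.save('demo_sft.wav', tts_speech, 22050)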
runtime/python/grpc/client.py
CHANGED
@@ -96,7 +96,8 @@ if __name__ == "__main__":
96 |                         default='../../../zero_shot_prompt.wav')
97 |     parser.add_argument('--instruct_text',
98 |                         type=str,
99 | -                       default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
100 |     parser.add_argument('--tts_wav',
101 |                         type=str,
102 |                         default='demo.wav')

96 |                         default='../../../zero_shot_prompt.wav')
97 |     parser.add_argument('--instruct_text',
98 |                         type=str,
99 | +                       default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
100 | +                               Fights with fervor for justice, but struggles with impulsiveness.')
101 |     parser.add_argument('--tts_wav',
102 |                         type=str,
103 |                         default='demo.wav')
runtime/python/grpc/server.py
CHANGED
@@ -13,9 +13,6 @@
13 | # limitations under the License.
14 | import os
15 | import sys
16 | - ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17 | - sys.path.append('{}/../../..'.format(ROOT_DIR))
18 | - sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
19 | from concurrent import futures
20 | import argparse
21 | import cosyvoice_pb2
@@ -25,11 +22,15 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
25 | import grpc
26 | import torch
27 | import numpy as np
28 | from cosyvoice.cli.cosyvoice import CosyVoice
29 |
30 | logging.basicConfig(level=logging.DEBUG,
31 |                     format='%(asctime)s %(levelname)s %(message)s')
32 |
33 | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
34 |     def __init__(self, args):
35 |         self.cosyvoice = CosyVoice(args.model_dir)
@@ -43,7 +44,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
43 |             logging.info('get zero_shot inference request')
44 |             prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
45 |             prompt_speech_16k = prompt_speech_16k.float() / (2**15)
46 | -           model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
47 |         elif request.HasField('cross_lingual_request'):
48 |             logging.info('get cross_lingual inference request')
49 |             prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
@@ -51,7 +54,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
51 |             model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
52 |         else:
53 |             logging.info('get instruct inference request')
54 | -           model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
55 |
56 |         logging.info('send inference response')
57 |         for i in model_output:
@@ -59,6 +64,7 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
59 |             response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
60 |             yield response
61 |
62 | def main():
63 |     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
64 |     cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
13 | # limitations under the License.
14 | import os
15 | import sys
16 | from concurrent import futures
17 | import argparse
18 | import cosyvoice_pb2
22 | import grpc
23 | import torch
24 | import numpy as np
25 | + ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
26 | + sys.path.append('{}/../../..'.format(ROOT_DIR))
27 | + sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
28 | from cosyvoice.cli.cosyvoice import CosyVoice
29 |
30 | logging.basicConfig(level=logging.DEBUG,
31 |                     format='%(asctime)s %(levelname)s %(message)s')
32 |
33 | +
34 | class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
35 |     def __init__(self, args):
36 |         self.cosyvoice = CosyVoice(args.model_dir)
44 |             logging.info('get zero_shot inference request')
45 |             prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
46 |             prompt_speech_16k = prompt_speech_16k.float() / (2**15)
47 | +           model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
48 | +                                                              request.zero_shot_request.prompt_text,
49 | +                                                              prompt_speech_16k)
50 |         elif request.HasField('cross_lingual_request'):
51 |             logging.info('get cross_lingual inference request')
52 |             prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
54 |             model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
55 |         else:
56 |             logging.info('get instruct inference request')
57 | +           model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
58 | +                                                            request.instruct_request.spk_id,
59 | +                                                            request.instruct_request.instruct_text)
60 |
61 |         logging.info('send inference response')
62 |         for i in model_output:
64 |             response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
65 |             yield response
66 |
67 | +
68 | def main():
69 |     grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
70 |     cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
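A note on the wire format above: prompt_audio is decoded with np.frombuffer(..., dtype=np.int16) and rescaled by 2 ** 15, so a client has to send 16 kHz, 16-bit little-endian PCM bytes. A hedged sketch of preparing that payload follows; the cosyvoice_pb2.Request message with a zero_shot_request sub-message is assumed from the field accesses in the server code, not spelled out in this diff, and the prompt file name is illustrative:

import numpy as np
import torchaudio

import cosyvoice_pb2

# Load a mono prompt recording and resample it to the 16 kHz the server expects.
speech, sample_rate = torchaudio.load('zero_shot_prompt.wav')
if sample_rate != 16000:
    speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=16000)(speech)

# Assumed message layout: Request with a zero_shot_request sub-message.
request = cosyvoice_pb2.Request()
request.zero_shot_request.tts_text = 'Text to be synthesized in the prompt speaker voice.'
request.zero_shot_request.prompt_text = 'Transcript of the prompt recording.'
# Inverse of the server-side decode: scale float [-1, 1] samples to int16 and serialize.
request.zero_shot_request.prompt_audio = (speech.numpy() * (2 ** 15)).astype(np.int16).tobytes()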
tools/extract_embedding.py
CHANGED
@@ -59,6 +59,7 @@ def main(args):
59 |     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60 |     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61 |
62 | if __name__ == "__main__":
63 |     parser = argparse.ArgumentParser()
64 |     parser.add_argument('--dir',

59 |     torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60 |     torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61 |
62 | +
63 | if __name__ == "__main__":
64 |     parser = argparse.ArgumentParser()
65 |     parser.add_argument('--dir',
tools/make_parquet_list.py
CHANGED
@@ -53,6 +53,7 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
53 |         json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54 |     logging.info('spend time {}'.format(time.time() - start_time))
55 |
56 | if __name__ == "__main__":
57 |     parser = argparse.ArgumentParser()
58 |     parser.add_argument('--num_utts_per_parquet',

53 |         json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54 |     logging.info('spend time {}'.format(time.time() - start_time))
55 |
56 | +
57 | if __name__ == "__main__":
58 |     parser = argparse.ArgumentParser()
59 |     parser.add_argument('--num_utts_per_parquet',
webui.py
CHANGED
@@ -13,9 +13,6 @@
13 | # limitations under the License.
14 | import os
15 | import sys
16 | - ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17 | - sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
18 | -
19 | import argparse
20 | import gradio as gr
21 | import numpy as np
@@ -23,9 +20,19 @@ import torch
23 | import torchaudio
24 | import random
25 | import librosa
26 | -
27 | from cosyvoice.cli.cosyvoice import CosyVoice
28 | - from cosyvoice.utils.file_utils import load_wav,
29 |
30 | def generate_seed():
31 |     seed = random.randint(1, 100000000)
@@ -34,13 +41,14 @@ def generate_seed():
34 |         "value": seed
35 |     }
36 |
37 | def set_all_random_seed(seed):
38 |     random.seed(seed)
39 |     np.random.seed(seed)
40 |     torch.manual_seed(seed)
41 |     torch.cuda.manual_seed_all(seed)
42 |
43 | -
44 | def postprocess(speech, top_db=60, hop_length=220, win_length=440):
45 |     speech, _ = librosa.effects.trim(
46 |         speech, top_db=top_db,
@@ -52,16 +60,13 @@ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
52 |     speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
53 |     return speech
54 |
55 | -
56 | - instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
57 | -                  '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
58 | -                  '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
59 | -                  '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
60 | - stream_mode_list = [('否', False), ('是', True)]
61 | def change_instruction(mode_checkbox_group):
62 |     return instruct_dict[mode_checkbox_group]
63 |
64 | -
65 |     if prompt_wav_upload is not None:
66 |         prompt_wav = prompt_wav_upload
67 |     elif prompt_wav_record is not None:
@@ -72,31 +77,31 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
72 |     if mode_checkbox_group in ['自然语言控制']:
73 |         if cosyvoice.frontend.instruct is False:
74 |             gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
75 | -
76 |         if instruct_text == '':
77 |             gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
78 | -
79 |         if prompt_wav is not None or prompt_text != '':
80 |             gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
81 |     # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
82 |     if mode_checkbox_group in ['跨语种复刻']:
83 |         if cosyvoice.frontend.instruct is True:
84 |             gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
85 | -
86 |         if instruct_text != '':
87 |             gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
88 |         if prompt_wav is None:
89 |             gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
90 | -
91 |         gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
92 |     # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
93 |     if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
94 |         if prompt_wav is None:
95 |             gr.Warning('prompt音频为空，您是否忘记输入prompt音频？')
96 | -
97 |         if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
98 |             gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
99 | -
100 |     # sft mode only use sft_dropdown
101 |     if mode_checkbox_group in ['预训练音色']:
102 |         if instruct_text != '' or prompt_wav is not None or prompt_text != '':
@@ -105,7 +110,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
105 |     if mode_checkbox_group in ['3s极速复刻']:
106 |         if prompt_text == '':
107 |             gr.Warning('prompt文本为空，您是否忘记输入prompt文本？')
108 | -
109 |         if instruct_text != '':
110 |             gr.Info('您正在使用3s极速复刻模式，预训练音色/instruct文本会被忽略！')
111 |
@@ -113,28 +118,32 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
113 |         logging.info('get sft inference request')
114 |         set_all_random_seed(seed)
115 |         for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
116 | -           yield (target_sr,
117 |     elif mode_checkbox_group == '3s极速复刻':
118 |         logging.info('get zero_shot inference request')
119 |         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
120 |         set_all_random_seed(seed)
121 |         for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
122 | -           yield (target_sr,
123 |     elif mode_checkbox_group == '跨语种复刻':
124 |         logging.info('get cross_lingual inference request')
125 |         prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
126 |         set_all_random_seed(seed)
127 |         for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
128 | -           yield (target_sr,
129 |     else:
130 |         logging.info('get instruct inference request')
131 |         set_all_random_seed(seed)
132 |         for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
133 | -           yield (target_sr,
134 |
135 | def main():
136 |     with gr.Blocks() as demo:
137 | -       gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice)
138 |         gr.Markdown("#### 请输入需要合成的文本，选择推理模式，并按照提示步骤进行操作")
139 |
140 |         tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型，提供舒适自然的语音合成能力。")
@@ -160,12 +169,14 @@ def main():
160 |
161 |         seed_button.click(generate_seed, inputs=[], outputs=seed)
162 |         generate_button.click(generate_audio,
163 | -                             inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
164 |                               outputs=[audio_output])
165 |         mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
166 |         demo.queue(max_size=4, default_concurrency_limit=2)
167 |         demo.launch(server_name='0.0.0.0', server_port=args.port)
168 |
169 | if __name__ == '__main__':
170 |     parser = argparse.ArgumentParser()
171 |     parser.add_argument('--port',
# limitations under the License.
|
14 |
import os
|
15 |
import sys
|
|
|
|
|
|
|
16 |
import argparse
|
17 |
import gradio as gr
|
18 |
import numpy as np
|
|
|
20 |
import torchaudio
|
21 |
import random
|
22 |
import librosa
|
23 |
+
ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
|
24 |
+
sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
|
25 |
from cosyvoice.cli.cosyvoice import CosyVoice
|
26 |
+
from cosyvoice.utils.file_utils import load_wav, logging
|
27 |
+
|
28 |
+
inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
|
29 |
+
instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
|
30 |
+
'3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
|
31 |
+
'跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
|
32 |
+
'自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
|
33 |
+
stream_mode_list = [('否', False), ('是', True)]
|
34 |
+
max_val = 0.8
|
35 |
+
|
36 |
|
37 |
def generate_seed():
|
38 |
seed = random.randint(1, 100000000)
|
|
|
41 |
"value": seed
|
42 |
}
|
43 |
|
44 |
+
|
45 |
def set_all_random_seed(seed):
|
46 |
random.seed(seed)
|
47 |
np.random.seed(seed)
|
48 |
torch.manual_seed(seed)
|
49 |
torch.cuda.manual_seed_all(seed)
|
50 |
|
51 |
+
|
52 |
def postprocess(speech, top_db=60, hop_length=220, win_length=440):
|
53 |
speech, _ = librosa.effects.trim(
|
54 |
speech, top_db=top_db,
|
|
|
60 |
speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
|
61 |
return speech
|
62 |
|
63 |
+
|
|
|
|
|
|
|
|
|
|
|
64 |
def change_instruction(mode_checkbox_group):
|
65 |
return instruct_dict[mode_checkbox_group]
|
66 |
|
67 |
+
|
68 |
+
def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
|
69 |
+
seed, stream, speed_factor):
|
70 |
if prompt_wav_upload is not None:
|
71 |
prompt_wav = prompt_wav_upload
|
72 |
elif prompt_wav_record is not None:
|
|
|
77 |
if mode_checkbox_group in ['自然语言控制']:
|
78 |
if cosyvoice.frontend.instruct is False:
|
79 |
gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
|
80 |
+
yield (target_sr, default_data)
|
81 |
if instruct_text == '':
|
82 |
gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
|
83 |
+
yield (target_sr, default_data)
|
84 |
if prompt_wav is not None or prompt_text != '':
|
85 |
gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
|
86 |
# if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
|
87 |
if mode_checkbox_group in ['跨语种复刻']:
|
88 |
if cosyvoice.frontend.instruct is True:
|
89 |
gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
|
90 |
+
yield (target_sr, default_data)
|
91 |
if instruct_text != '':
|
92 |
gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
|
93 |
if prompt_wav is None:
|
94 |
gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
|
95 |
+
yield (target_sr, default_data)
|
96 |
gr.Info('您正在使用跨语种复刻模式, 请确保合成文本��prompt文本为不同语言')
|
97 |
# if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
|
98 |
if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
|
99 |
if prompt_wav is None:
|
100 |
gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
|
101 |
+
yield (target_sr, default_data)
|
102 |
if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
|
103 |
gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
|
104 |
+
yield (target_sr, default_data)
|
105 |
# sft mode only use sft_dropdown
|
106 |
if mode_checkbox_group in ['预训练音色']:
|
107 |
if instruct_text != '' or prompt_wav is not None or prompt_text != '':
|
|
|
110 |
if mode_checkbox_group in ['3s极速复刻']:
|
111 |
if prompt_text == '':
|
112 |
gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
|
113 |
+
yield (target_sr, default_data)
|
114 |
if instruct_text != '':
|
115 |
gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
|
116 |
|
|
|
118 |
logging.info('get sft inference request')
|
119 |
set_all_random_seed(seed)
|
120 |
for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
|
121 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
122 |
elif mode_checkbox_group == '3s极速复刻':
|
123 |
logging.info('get zero_shot inference request')
|
124 |
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
125 |
set_all_random_seed(seed)
|
126 |
for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
|
127 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
128 |
elif mode_checkbox_group == '跨语种复刻':
|
129 |
logging.info('get cross_lingual inference request')
|
130 |
prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
|
131 |
set_all_random_seed(seed)
|
132 |
for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
|
133 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
134 |
else:
|
135 |
logging.info('get instruct inference request')
|
136 |
set_all_random_seed(seed)
|
137 |
for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
|
138 |
+
yield (target_sr, i['tts_speech'].numpy().flatten())
|
139 |
+
|
140 |
|
141 |
def main():
|
142 |
with gr.Blocks() as demo:
|
143 |
+
gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
|
144 |
+
预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
|
145 |
+
[CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
|
146 |
+
[CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
|
147 |
gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
|
148 |
|
149 |
tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
|
|
|
169 |
|
170 |
seed_button.click(generate_seed, inputs=[], outputs=seed)
|
171 |
generate_button.click(generate_audio,
|
172 |
+
inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
|
173 |
+
seed, stream, speed_factor],
|
174 |
outputs=[audio_output])
|
175 |
mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
|
176 |
demo.queue(max_size=4, default_concurrency_limit=2)
|
177 |
demo.launch(server_name='0.0.0.0', server_port=args.port)
|
178 |
|
179 |
+
|
180 |
if __name__ == '__main__':
|
181 |
parser = argparse.ArgumentParser()
|
182 |
parser.add_argument('--port',
|
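The webui changes above turn generate_audio into a generator that yields (sample_rate, waveform) tuples for Gradio's streaming audio output. The same inference calls can be exercised without the UI; the sketch below is not part of this change and assumes a locally downloaded CosyVoice-300M-SFT model, a list_avaliable_spks() helper on the CosyVoice class, and a 22050 Hz output rate (check your checkpoint):

import torchaudio

from cosyvoice.cli.cosyvoice import CosyVoice

# Minimal offline sketch of the inference_sft call used by the webui.
# Assumptions (not from this diff): model files under pretrained_models/CosyVoice-300M-SFT,
# at least one pretrained speaker id, 22050 Hz output audio.
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')
spk_id = cosyvoice.list_avaliable_spks()[0]
for k, out in enumerate(cosyvoice.inference_sft('你好，这是一条测试语音。', spk_id, stream=False)):
    # Each item carries a [1, T] float tensor, the same 'tts_speech' the webui flattens per chunk.
    torchaudio.save('sft_demo_{}.wav'.format(k), out['tts_speech'], 22050)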