CosyVoice committed
Commit 90433f5 · 1 Parent(s): eeebc45
.github/workflows/lint.yml CHANGED
@@ -51,5 +51,5 @@ jobs:
 set -eux
 pip install flake8==3.8.2 flake8-bugbear flake8-comprehensions flake8-executable flake8-pyi==20.5.0 mccabe pycodestyle==2.6.0 pyflakes==2.2.0
 flake8 --version
-flake8 --max-line-length 120 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
+flake8 --max-line-length 150 --ignore B006,B008,B905,C408,E402,E741,W503,W504 --exclude ./third_party/,./runtime/python/grpc/cosyvoice_pb2*py
 if [ $? != 0 ]; then exit 1; fi
cosyvoice/bin/export_jit.py CHANGED
@@ -19,12 +19,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
+import torch
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
 sys.path.append('{}/../..'.format(ROOT_DIR))
 sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
-import torch
 from cosyvoice.cli.cosyvoice import CosyVoice

+
 def get_args():
 parser = argparse.ArgumentParser(description='export your model for deployment')
 parser.add_argument('--model_dir',
@@ -35,6 +36,7 @@ def get_args():
 print(args)
 return args

+
 def main():
 args = get_args()
 logging.basicConfig(level=logging.DEBUG,
@@ -67,5 +69,6 @@ def main():
 script = torch.jit.optimize_for_inference(script)
 script.save('{}/flow.encoder.fp32.zip'.format(args.model_dir))

+
 if __name__ == '__main__':
 main()
cosyvoice/bin/export_onnx.py CHANGED
@@ -20,13 +20,13 @@ import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../..'.format(ROOT_DIR))
-sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 import onnxruntime
 import random
 import torch
 from tqdm import tqdm
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../..'.format(ROOT_DIR))
+sys.path.append('{}/../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice

@@ -50,6 +50,7 @@ def get_args():
 print(args)
 return args

+
 def main():
 args = get_args()
 logging.basicConfig(level=logging.DEBUG,
@@ -89,7 +90,8 @@ def main():
 option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
 option.intra_op_num_threads = 1
 providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
-estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir), sess_options=option, providers=providers)
+estimator_onnx = onnxruntime.InferenceSession('{}/flow.decoder.estimator.fp32.onnx'.format(args.model_dir),
+sess_options=option, providers=providers)

 for _ in tqdm(range(10)):
 x, mask, mu, t, spks, cond = get_dummy_input(random.randint(1, 6), random.randint(16, 512), out_channels, device)
@@ -105,5 +107,6 @@ def main():
 output_onnx = estimator_onnx.run(None, ort_inputs)[0]
 torch.testing.assert_allclose(output_pytorch, torch.from_numpy(output_onnx).to(device), rtol=1e-2, atol=1e-4)

+
 if __name__ == "__main__":
 main()
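For context, the wrapped call above follows the usual onnxruntime pattern: build SessionOptions, pick a provider based on torch.cuda.is_available(), then compare the ONNX output against a reference with loose tolerances. A minimal self-contained sketch of that pattern; the model path, input shape and the reference tensor here are placeholders, not values from this commit:

import numpy as np
import onnxruntime
import torch

option = onnxruntime.SessionOptions()
option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
option.intra_op_num_threads = 1
providers = ['CUDAExecutionProvider' if torch.cuda.is_available() else 'CPUExecutionProvider']
session = onnxruntime.InferenceSession('model.onnx', sess_options=option, providers=providers)  # placeholder path

x = torch.randn(1, 80, 128)  # placeholder input shape
ort_inputs = {session.get_inputs()[0].name: x.numpy()}
output_onnx = session.run(None, ort_inputs)[0]
# Compare against a reference tensor with loose tolerances, as the script does.
output_ref = torch.from_numpy(output_onnx)  # stand-in for the PyTorch reference output
torch.testing.assert_close(output_ref, torch.from_numpy(output_onnx), rtol=1e-2, atol=1e-4)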
cosyvoice/bin/inference.py CHANGED
@@ -18,16 +18,15 @@ import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
 import os
-
 import torch
 from torch.utils.data import DataLoader
 import torchaudio
 from hyperpyyaml import load_hyperpyyaml
 from tqdm import tqdm
 from cosyvoice.cli.model import CosyVoiceModel
-
 from cosyvoice.dataset.dataset import Dataset

+
 def get_args():
 parser = argparse.ArgumentParser(description='inference with your model')
 parser.add_argument('--config', required=True, help='config file')
@@ -66,7 +65,8 @@ def main():
 model = CosyVoiceModel(configs['llm'], configs['flow'], configs['hift'])
 model.load(args.llm_model, args.flow_model, args.hifigan_model)

-test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False, tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
+test_dataset = Dataset(args.prompt_data, data_pipeline=configs['data_pipeline'], mode='inference', shuffle=False, partition=False,
+tts_file=args.tts_text, prompt_utt2data=args.prompt_utt2data)
 test_data_loader = DataLoader(test_dataset, batch_size=None, num_workers=0)

 del configs
@@ -74,13 +74,11 @@ def main():
 fn = os.path.join(args.result_dir, 'wav.scp')
 f = open(fn, 'w')
 with torch.no_grad():
-for batch_idx, batch in tqdm(enumerate(test_data_loader)):
+for _, batch in tqdm(enumerate(test_data_loader)):
 utts = batch["utts"]
 assert len(utts) == 1, "inference mode only support batchsize 1"
-text = batch["text"]
 text_token = batch["text_token"].to(device)
 text_token_len = batch["text_token_len"].to(device)
-tts_text = batch["tts_text"]
 tts_index = batch["tts_index"]
 tts_text_token = batch["tts_text_token"].to(device)
 tts_text_token_len = batch["tts_text_token_len"].to(device)
cosyvoice/bin/train.py CHANGED
@@ -132,5 +132,6 @@ def main():
 executor.train_one_epoc(model, optimizer, scheduler, train_data_loader, cv_data_loader, writer, info_dict, group_join)
 dist.destroy_process_group(group_join)

+
 if __name__ == '__main__':
 main()
cosyvoice/cli/cosyvoice.py CHANGED
@@ -20,6 +20,7 @@ from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
 from cosyvoice.utils.file_utils import logging

+
 class CosyVoice:

 def __init__(self, model_dir, load_jit=True, load_onnx=True):
@@ -42,8 +43,8 @@ class CosyVoice:
 '{}/hift.pt'.format(model_dir))
 if load_jit:
 self.model.load_jit('{}/llm.text_encoder.fp16.zip'.format(model_dir),
-'{}/llm.llm.fp16.zip'.format(model_dir),
-'{}/flow.encoder.fp32.zip'.format(model_dir))
+'{}/llm.llm.fp16.zip'.format(model_dir),
+'{}/flow.encoder.fp32.zip'.format(model_dir))
 if load_onnx:
 self.model.load_onnx('{}/flow.decoder.estimator.fp32.onnx'.format(model_dir))
 del configs
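The hunk above shows the constructor signature and the checkpoint files it loads; a hedged usage sketch follows. The model directory is an assumed local path, and the text and speaker id passed to inference_sft (the API used by the FastAPI/gRPC servers later in this commit) are made-up examples:

from cosyvoice.cli.cosyvoice import CosyVoice

# load_jit / load_onnx toggle the TorchScript and ONNX artifacts listed in the diff.
cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M', load_jit=False, load_onnx=False)
# inference_sft yields dicts containing a 'tts_speech' tensor, as consumed by the servers.
for out in cosyvoice.inference_sft('Hello, CosyVoice.', 'some_spk_id'):
    print(out['tts_speech'].shape)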
cosyvoice/cli/frontend.py CHANGED
@@ -50,7 +50,9 @@ class CosyVoiceFrontEnd:
 option.graph_optimization_level = onnxruntime.GraphOptimizationLevel.ORT_ENABLE_ALL
 option.intra_op_num_threads = 1
 self.campplus_session = onnxruntime.InferenceSession(campplus_model, sess_options=option, providers=["CPUExecutionProvider"])
-self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option, providers=["CUDAExecutionProvider"if torch.cuda.is_available() else "CPUExecutionProvider"])
+self.speech_tokenizer_session = onnxruntime.InferenceSession(speech_tokenizer_model, sess_options=option,
+providers=["CUDAExecutionProvider" if torch.cuda.is_available() else
+"CPUExecutionProvider"])
 if os.path.exists(spk2info):
 self.spk2info = torch.load(spk2info, map_location=self.device)
 self.instruct = instruct
@@ -60,7 +62,8 @@ class CosyVoiceFrontEnd:
 if self.use_ttsfrd:
 self.frd = ttsfrd.TtsFrontendEngine()
 ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, 'failed to initialize ttsfrd resource'
+assert self.frd.initialize('{}/../../pretrained_models/CosyVoice-ttsfrd/resource'.format(ROOT_DIR)) is True, \
+'failed to initialize ttsfrd resource'
 self.frd.set_lang_type('pinyin')
 self.frd.enable_pinyin_mix(True)
 self.frd.set_breakmodel_index(1)
@@ -76,8 +79,11 @@ class CosyVoiceFrontEnd:

 def _extract_speech_token(self, speech):
 feat = whisper.log_mel_spectrogram(speech, n_mels=128)
-speech_token = self.speech_tokenizer_session.run(None, {self.speech_tokenizer_session.get_inputs()[0].name: feat.detach().cpu().numpy(),
-self.speech_tokenizer_session.get_inputs()[1].name: np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
+speech_token = self.speech_tokenizer_session.run(None,
+{self.speech_tokenizer_session.get_inputs()[0].name:
+feat.detach().cpu().numpy(),
+self.speech_tokenizer_session.get_inputs()[1].name:
+np.array([feat.shape[2]], dtype=np.int32)})[0].flatten().tolist()
 speech_token = torch.tensor([speech_token], dtype=torch.int32).to(self.device)
 speech_token_len = torch.tensor([speech_token.shape[1]], dtype=torch.int32).to(self.device)
 return speech_token, speech_token_len
@@ -88,7 +94,8 @@ class CosyVoiceFrontEnd:
 dither=0,
 sample_frequency=16000)
 feat = feat - feat.mean(dim=0, keepdim=True)
-embedding = self.campplus_session.run(None, {self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
+embedding = self.campplus_session.run(None,
+{self.campplus_session.get_inputs()[0].name: feat.unsqueeze(dim=0).cpu().numpy()})[0].flatten().tolist()
 embedding = torch.tensor([embedding]).to(self.device)
 return embedding

@@ -112,18 +119,16 @@ class CosyVoiceFrontEnd:
 text = text.replace(" - ", ",")
 text = remove_bracket(text)
 text = re.sub(r'[,，]+$', '。', text)
-texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
-token_min_n=60, merge_len=20,
-comma_split=False)]
+texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "zh", token_max_n=80,
+token_min_n=60, merge_len=20, comma_split=False))
 else:
 if self.use_ttsfrd:
 text = self.frd.get_frd_extra_info(text, 'input')
 else:
 text = self.en_tn_model.normalize(text)
 text = spell_out_number(text, self.inflect_parser)
-texts = [i for i in split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
-token_min_n=60, merge_len=20,
-comma_split=False)]
+texts = list(split_paragraph(text, partial(self.tokenizer.encode, allowed_special=self.allowed_special), "en", token_max_n=80,
+token_min_n=60, merge_len=20, comma_split=False))
 if split is False:
 return text
 return texts
cosyvoice/cli/model.py CHANGED
@@ -18,7 +18,7 @@ import time
 from contextlib import nullcontext
 import uuid
 from cosyvoice.utils.common import fade_in_out
-import numpy as np
+

 class CosyVoiceModel:

@@ -80,27 +80,27 @@ class CosyVoiceModel:
 def llm_job(self, text, prompt_text, llm_prompt_speech_token, llm_embedding, uuid):
 with self.llm_context:
 for i in self.llm.inference(text=text.to(self.device),
-text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
-prompt_text=prompt_text.to(self.device),
-prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
-prompt_speech_token=llm_prompt_speech_token.to(self.device),
-prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
-embedding=llm_embedding.to(self.device).half(),
-sampling=25,
-max_token_text_ratio=30,
-min_token_text_ratio=3):
+text_len=torch.tensor([text.shape[1]], dtype=torch.int32).to(self.device),
+prompt_text=prompt_text.to(self.device),
+prompt_text_len=torch.tensor([prompt_text.shape[1]], dtype=torch.int32).to(self.device),
+prompt_speech_token=llm_prompt_speech_token.to(self.device),
+prompt_speech_token_len=torch.tensor([llm_prompt_speech_token.shape[1]], dtype=torch.int32).to(self.device),
+embedding=llm_embedding.to(self.device).half(),
+sampling=25,
+max_token_text_ratio=30,
+min_token_text_ratio=3):
 self.tts_speech_token_dict[uuid].append(i)
 self.llm_end_dict[uuid] = True

 def token2wav(self, token, prompt_token, prompt_feat, embedding, uuid, finalize=False):
 with self.flow_hift_context:
 tts_mel = self.flow.inference(token=token.to(self.device),
-token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
-prompt_token=prompt_token.to(self.device),
-prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
-prompt_feat=prompt_feat.to(self.device),
-prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
-embedding=embedding.to(self.device))
+token_len=torch.tensor([token.shape[1]], dtype=torch.int32).to(self.device),
+prompt_token=prompt_token.to(self.device),
+prompt_token_len=torch.tensor([prompt_token.shape[1]], dtype=torch.int32).to(self.device),
+prompt_feat=prompt_feat.to(self.device),
+prompt_feat_len=torch.tensor([prompt_feat.shape[1]], dtype=torch.int32).to(self.device),
+embedding=embedding.to(self.device))
 # mel overlap fade in out
 if self.mel_overlap_dict[uuid] is not None:
 tts_mel = fade_in_out(tts_mel, self.mel_overlap_dict[uuid], self.mel_window)
@@ -129,7 +129,8 @@ class CosyVoiceModel:
 # this_uuid is used to track variables related to this inference thread
 this_uuid = str(uuid.uuid1())
 with self.lock:
-self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid], self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = [], False, None, None
+self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = [], False
+self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
 p = threading.Thread(target=self.llm_job, args=(text, prompt_text, llm_prompt_speech_token, llm_embedding, this_uuid))
 p.start()
 if stream is True:
@@ -140,12 +141,12 @@ class CosyVoiceModel:
 this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
 with self.flow_hift_context:
 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-prompt_token=flow_prompt_speech_token,
-prompt_feat=prompt_speech_feat,
-embedding=flow_embedding,
-uuid=this_uuid,
-finalize=False)
-yield {'tts_speech': this_tts_speech.cpu()}
+prompt_token=flow_prompt_speech_token,
+prompt_feat=prompt_speech_feat,
+embedding=flow_embedding,
+uuid=this_uuid,
+finalize=False)
+yield {'tts_speech': this_tts_speech.cpu()}
 with self.lock:
 self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
 # increase token_hop_len for better speech quality
@@ -157,11 +158,11 @@ class CosyVoiceModel:
 this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
 with self.flow_hift_context:
 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-prompt_token=flow_prompt_speech_token,
-prompt_feat=prompt_speech_feat,
-embedding=flow_embedding,
-uuid=this_uuid,
-finalize=True)
+prompt_token=flow_prompt_speech_token,
+prompt_feat=prompt_speech_feat,
+embedding=flow_embedding,
+uuid=this_uuid,
+finalize=True)
 yield {'tts_speech': this_tts_speech.cpu()}
 else:
 # deal with all tokens
@@ -169,11 +170,11 @@ class CosyVoiceModel:
 this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
 with self.flow_hift_context:
 this_tts_speech = self.token2wav(token=this_tts_speech_token,
-prompt_token=flow_prompt_speech_token,
-prompt_feat=prompt_speech_feat,
-embedding=flow_embedding,
-uuid=this_uuid,
-finalize=True)
+prompt_token=flow_prompt_speech_token,
+prompt_feat=prompt_speech_feat,
+embedding=flow_embedding,
+uuid=this_uuid,
+finalize=True)
 yield {'tts_speech': this_tts_speech.cpu()}
 with self.lock:
 self.tts_speech_token_dict.pop(this_uuid)
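The split assignment above initializes per-request state: dictionaries keyed by a fresh uuid, set up under a lock while a worker thread appends generated speech tokens. A stripped-down, self-contained sketch of that bookkeeping pattern; the names mirror the diff but this is not the class's real API:

import threading
import uuid

lock = threading.Lock()
tts_speech_token_dict, llm_end_dict = {}, {}

def llm_job(this_uuid):
    # stand-in for the real LLM loop that streams speech tokens into the shared dict
    for token in range(5):
        tts_speech_token_dict[this_uuid].append(token)
    llm_end_dict[this_uuid] = True

this_uuid = str(uuid.uuid1())
with lock:
    tts_speech_token_dict[this_uuid], llm_end_dict[this_uuid] = [], False
p = threading.Thread(target=llm_job, args=(this_uuid,))
p.start()
p.join()
with lock:
    tokens = tts_speech_token_dict.pop(this_uuid)
    llm_end_dict.pop(this_uuid)
print(tokens)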
cosyvoice/dataset/dataset.py CHANGED
@@ -148,7 +148,7 @@ def Dataset(data_list_file,
 tts_data = json.load(f)
 utt2lists = read_json_lists(prompt_utt2data)
 # filter unnecessary file in inference mode
-lists = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
+lists = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
 dataset = DataList(lists,
 shuffle=shuffle,
 partition=partition)
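The change above swaps list(set([...])) for a set comprehension, one of the flake8 comprehension rules enabled by this commit. A tiny check of the equivalence, with a made-up utt2lists mapping for illustration:

utt2lists = {'utt1': 'a.list', 'utt2': 'a.list', 'utt3': 'b.list'}
tts_data = {'utt1': [], 'utt2': [], 'utt3': []}
lists = ['a.list', 'b.list']

old = list(set([utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists]))
new = list({utt2lists[utt] for utt in tts_data.keys() if utt2lists[utt] in lists})
assert sorted(old) == sorted(new)  # same de-duplicated contents, order aside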
cosyvoice/dataset/processor.py CHANGED
@@ -23,7 +23,7 @@ import torch.nn.functional as F

 torchaudio.set_audio_backend('soundfile')

-AUDIO_FORMAT_SETS = set(['flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'])
+AUDIO_FORMAT_SETS = {'flac', 'mp3', 'm4a', 'ogg', 'opus', 'wav', 'wma'}


 def parquet_opener(data, mode='train', tts_data={}):
@@ -54,6 +54,7 @@ def parquet_opener(data, mode='train', tts_data={}):
 except Exception as ex:
 logging.warning('Failed to open {}, ex info {}'.format(url, ex))

+
 def filter(data,
 max_length=10240,
 min_length=10,
cosyvoice/flow/decoder.py CHANGED
@@ -74,7 +74,7 @@ class ConditionalDecoder(nn.Module):
 )
 self.down_blocks.append(nn.ModuleList([resnet, transformer_blocks, downsample]))

-for i in range(num_mid_blocks):
+for _ in range(num_mid_blocks):
 input_channel = channels[-1]
 out_channels = channels[-1]
 resnet = ResnetBlock1D(dim=input_channel, dim_out=output_channel, time_emb_dim=time_embed_dim)
@@ -126,7 +126,6 @@ class ConditionalDecoder(nn.Module):
 self.final_proj = nn.Conv1d(channels[-1], self.out_channels, 1)
 self.initialize_weights()

-
 def initialize_weights(self):
 for m in self.modules():
 if isinstance(m, nn.Conv1d):
cosyvoice/flow/flow.py CHANGED
@@ -33,8 +33,13 @@ class MaskedDiffWithXvec(torch.nn.Module):
 encoder: torch.nn.Module = None,
 length_regulator: torch.nn.Module = None,
 decoder: torch.nn.Module = None,
-decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1, 'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine', 'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}), 'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64, 'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
-mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050, 'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
+decoder_conf: Dict = {'in_channels': 240, 'out_channel': 80, 'spk_emb_dim': 80, 'n_spks': 1,
+'cfm_params': DictConfig({'sigma_min': 1e-06, 'solver': 'euler', 't_scheduler': 'cosine',
+'training_cfg_rate': 0.2, 'inference_cfg_rate': 0.7, 'reg_loss_type': 'l1'}),
+'decoder_params': {'channels': [256, 256], 'dropout': 0.0, 'attention_head_dim': 64,
+'n_blocks': 4, 'num_mid_blocks': 12, 'num_heads': 8, 'act_fn': 'gelu'}},
+mel_feat_conf: Dict = {'n_fft': 1024, 'num_mels': 80, 'sampling_rate': 22050,
+'hop_size': 256, 'win_size': 1024, 'fmin': 0, 'fmax': 8000}):
 super().__init__()
 self.input_size = input_size
 self.output_size = output_size
cosyvoice/flow/flow_matching.py CHANGED
@@ -15,6 +15,7 @@ import torch
 import torch.nn.functional as F
 from matcha.models.components.flow_matching import BASECFM

+
 class ConditionalCFM(BASECFM):
 def __init__(self, in_channels, cfm_params, n_spks=1, spk_emb_dim=64, estimator: torch.nn.Module = None):
 super().__init__(
cosyvoice/flow/length_regulator.py CHANGED
File without changes
cosyvoice/hifigan/f0_predictor.py CHANGED
File without changes
cosyvoice/hifigan/generator.py CHANGED
@@ -38,6 +38,8 @@ This code is modified from https://github.com/jik876/hifi-gan
 https://github.com/NVIDIA/BigVGAN

 """
+
+
 class ResBlock(torch.nn.Module):
 """Residual block module in HiFiGAN/BigVGAN."""
 def __init__(
@@ -100,6 +102,7 @@ class ResBlock(torch.nn.Module):
 remove_weight_norm(self.convs1[idx])
 remove_weight_norm(self.convs2[idx])

+
 class SineGen(torch.nn.Module):
 """ Definition of sine generator
 SineGen(samp_rate, harmonic_num = 0,
@@ -286,8 +289,7 @@ class HiFTGenerator(nn.Module):
 self.source_resblocks = nn.ModuleList()
 downsample_rates = [1] + upsample_rates[::-1][:-1]
 downsample_cum_rates = np.cumprod(downsample_rates)
-for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes,
-source_resblock_dilation_sizes)):
+for i, (u, k, d) in enumerate(zip(downsample_cum_rates[::-1], source_resblock_kernel_sizes, source_resblock_dilation_sizes)):
 if u == 1:
 self.source_downs.append(
 Conv1d(istft_params["n_fft"] + 2, base_channels // (2 ** (i + 1)), 1, 1)
@@ -304,7 +306,7 @@ class HiFTGenerator(nn.Module):
 self.resblocks = nn.ModuleList()
 for i in range(len(self.ups)):
 ch = base_channels // (2**(i + 1))
-for j, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
+for _, (k, d) in enumerate(zip(resblock_kernel_sizes, resblock_dilation_sizes)):
 self.resblocks.append(ResBlock(ch, k, d))

 self.conv_post = weight_norm(Conv1d(ch, istft_params["n_fft"] + 2, 7, 1, padding=3))
@@ -332,7 +334,8 @@ class HiFTGenerator(nn.Module):
 magnitude = torch.clip(magnitude, max=1e2)
 real = magnitude * torch.cos(phase)
 img = magnitude * torch.sin(phase)
-inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"], self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
+inverse_transform = torch.istft(torch.complex(real, img), self.istft_params["n_fft"], self.istft_params["hop_len"],
+self.istft_params["n_fft"], window=self.stft_window.to(magnitude.device))
 return inverse_transform

 def forward(self, x: torch.Tensor, cache_source: torch.Tensor = torch.zeros(1, 1, 0)) -> torch.Tensor:
cosyvoice/llm/llm.py CHANGED
@@ -80,7 +80,8 @@ class TransformerLM(torch.nn.Module):
 def pad_unpad_sequence(self, sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len):
 text_token = unpad_sequence(text_token, text_token_len.cpu(), batch_first=True)
 speech_token = unpad_sequence(speech_token, speech_token_len.cpu(), batch_first=True)
-lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0) for i in range(len(text_token))]
+lm_input = [torch.concat([sos_eos_emb.squeeze(dim=0), embedding[i], text_token[i], task_id_emb.squeeze(dim=0), speech_token[i]], dim=0)
+for i in range(len(text_token))]
 lm_input_len = torch.tensor([i.size(0) for i in lm_input], dtype=torch.int32)
 lm_input = pad_sequence(lm_input, batch_first=True, padding_value=IGNORE_ID)
 return lm_input, lm_input_len
@@ -104,7 +105,8 @@ class TransformerLM(torch.nn.Module):
 embedding = batch['embedding'].to(device)

 # 1. prepare llm_target
-lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() + [self.speech_token_size]) for i in range(text_token.size(0))]
+lm_target = [torch.tensor([IGNORE_ID] * (2 + text_token_len[i]) + speech_token[i, :speech_token_len[i]].tolist() +
+[self.speech_token_size]) for i in range(text_token.size(0))]
 lm_target = pad_sequence(lm_target, batch_first=True, padding_value=IGNORE_ID).to(device)

 # 1. encode text_token
@@ -124,7 +126,8 @@ class TransformerLM(torch.nn.Module):
 speech_token = self.speech_embedding(speech_token)

 # 5. unpad and pad
-lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len, task_id_emb, speech_token, speech_token_len)
+lm_input, lm_input_len = self.pad_unpad_sequence(sos_eos_emb, embedding, text_token, text_token_len,
+task_id_emb, speech_token, speech_token_len)

 # 6. run lm forward
 lm_output, lm_output_mask = self.llm(lm_input, lm_input_len.to(device))
@@ -194,8 +197,10 @@ class TransformerLM(torch.nn.Module):
 offset = 0
 att_cache, cnn_cache = torch.zeros((0, 0, 0, 0), device=lm_input.device), torch.zeros((0, 0, 0, 0), device=lm_input.device)
 for i in range(max_len):
-y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1, att_cache=att_cache, cnn_cache=cnn_cache,
-att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]), device=lm_input.device)).to(torch.bool))
+y_pred, att_cache, cnn_cache = self.llm.forward_chunk(lm_input, offset=0, required_cache_size=-1,
+att_cache=att_cache, cnn_cache=cnn_cache,
+att_mask=torch.tril(torch.ones((1, lm_input.shape[1], lm_input.shape[1]),
+device=lm_input.device)).to(torch.bool))
 logp = self.llm_decoder(y_pred[:, -1]).log_softmax(dim=-1)
 top_ids = self.sampling_ids(logp.squeeze(dim=0), out_tokens, sampling, ignore_eos=True if i < min_len else False).item()
 if top_ids == self.speech_token_size:
cosyvoice/transformer/embedding.py CHANGED
@@ -212,7 +212,7 @@ class EspnetRelPositionalEncoding(torch.nn.Module):

 """

-def __init__(self, d_model: int, dropout_rate: float, max_len: int=5000):
+def __init__(self, d_model: int, dropout_rate: float, max_len: int = 5000):
 """Construct an PositionalEncoding object."""
 super(EspnetRelPositionalEncoding, self).__init__()
 self.d_model = d_model
@@ -289,6 +289,6 @@ class EspnetRelPositionalEncoding(torch.nn.Module):
 """
 pos_emb = self.pe[
 :,
-self.pe.size(1) // 2 - size + 1 : self.pe.size(1) // 2 + size,
+self.pe.size(1) // 2 - size + 1: self.pe.size(1) // 2 + size,
 ]
 return pos_emb
cosyvoice/utils/common.py CHANGED
@@ -102,6 +102,7 @@ def init_weights(m, mean=0.0, std=0.01):
 if classname.find("Conv") != -1:
 m.weight.data.normal_(mean, std)

+
 # Repetition Aware Sampling in VALL-E 2
 def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25, win_size=10, tau_r=0.1):
 top_ids = nucleus_sampling(weighted_scores, top_p=top_p, top_k=top_k)
@@ -110,6 +111,7 @@ def ras_sampling(weighted_scores, decoded_tokens, sampling, top_p=0.8, top_k=25,
 top_ids = random_sampling(weighted_scores, decoded_tokens, sampling)
 return top_ids

+
 def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
 prob, indices = [], []
 cum_prob = 0.0
@@ -127,13 +129,16 @@ def nucleus_sampling(weighted_scores, top_p=0.8, top_k=25):
 top_ids = indices[prob.multinomial(1, replacement=True)]
 return top_ids

+
 def random_sampling(weighted_scores, decoded_tokens, sampling):
 top_ids = weighted_scores.softmax(dim=0).multinomial(1, replacement=True)
 return top_ids

+
 def fade_in_out(fade_in_mel, fade_out_mel, window):
 device = fade_in_mel.device
 fade_in_mel, fade_out_mel = fade_in_mel.cpu(), fade_out_mel.cpu()
 mel_overlap_len = int(window.shape[0] / 2)
-fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
+fade_in_mel[:, :, :mel_overlap_len] = fade_in_mel[:, :, :mel_overlap_len] * window[:mel_overlap_len] + \
+fade_out_mel[:, :, -mel_overlap_len:] * window[mel_overlap_len:]
 return fade_in_mel.to(device)
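fade_in_out, as defined above, blends the first mel_overlap_len frames of the incoming chunk with the tail of the previous one using the two halves of a window. A hedged usage sketch; the tensor shapes and the Hamming window choice are assumptions, not values from this commit:

import torch
from cosyvoice.utils.common import fade_in_out

window = torch.hamming_window(2 * 10)   # window length = 2 * overlap length
fade_in_mel = torch.randn(1, 80, 50)    # incoming mel chunk
fade_out_mel = torch.randn(1, 80, 50)   # previous chunk supplying the overlap tail
out = fade_in_out(fade_in_mel, fade_out_mel, window)
assert out.shape == fade_in_mel.shape   # only the first 10 frames are cross-faded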
cosyvoice/utils/executor.py CHANGED
@@ -70,7 +70,8 @@ class Executor:
 info_dict = update_parameter_and_lr(model, optimizer, scheduler, info_dict)
 log_per_step(writer, info_dict)
 # NOTE specify save_per_step in cosyvoice.yaml if you want to enable step save
-if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and (batch_idx + 1) % info_dict["accum_grad"] == 0:
+if info_dict['save_per_step'] > 0 and (self.step + 1) % info_dict['save_per_step'] == 0 and \
+(batch_idx + 1) % info_dict["accum_grad"] == 0:
 dist.barrier()
 self.cv(model, cv_data_loader, writer, info_dict, on_batch_end=False)
 model.train()
cosyvoice/utils/file_utils.py CHANGED
@@ -28,6 +28,7 @@ def read_lists(list_file):
 lists.append(line.strip())
 return lists

+
 def read_json_lists(list_file):
 lists = read_lists(list_file)
 results = {}
@@ -36,6 +37,7 @@ def read_json_lists(list_file):
 results.update(json.load(fin))
 return results

+
 def load_wav(wav, target_sr):
 speech, sample_rate = torchaudio.load(wav)
 speech = speech.mean(dim=0, keepdim=True)
@@ -44,6 +46,7 @@ def load_wav(wav, target_sr):
 speech = torchaudio.transforms.Resample(orig_freq=sample_rate, new_freq=target_sr)(speech)
 return speech

+
 def speed_change(waveform, sample_rate, speed_factor: str):
 effects = [
 ["tempo", speed_factor],  # speed_factor
cosyvoice/utils/frontend_utils.py CHANGED
@@ -15,6 +15,7 @@
 import re
 chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')

+
 # whether contain chinese character
 def contains_chinese(text):
 return bool(chinese_char_pattern.search(text))
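contains_chinese is small enough to check directly: the regex and function body below are copied from the hunk above, and the two asserts are illustrative inputs:

import re

chinese_char_pattern = re.compile(r'[\u4e00-\u9fff]+')

def contains_chinese(text):
    return bool(chinese_char_pattern.search(text))

assert contains_chinese('你好, world') is True
assert contains_chinese('hello world') is False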
cosyvoice/utils/scheduler.py CHANGED
@@ -567,8 +567,7 @@ class NoamAnnealing(_LRScheduler):
 min_lr=0.0,
 last_epoch=-1):
 self._normalize = d_model**(-0.5)
-assert not (warmup_steps is not None
-and warmup_ratio is not None), \
+assert not (warmup_steps is not None and warmup_ratio is not None), \
 "Either use particular number of step or ratio"
 assert warmup_ratio is None or max_steps is not None, \
 "If there is a ratio, there should be a total steps"
cosyvoice/utils/train_utils.py CHANGED
@@ -69,7 +69,6 @@ def init_dataset_and_dataloader(args, configs):
 return train_dataset, cv_dataset, train_data_loader, cv_data_loader


-
 def check_modify_and_save_config(args, configs):
 if args.train_engine == "torch_ddp":
 configs['train_conf']["dtype"] = 'fp32'
@@ -84,7 +83,8 @@ def check_modify_and_save_config(args, configs):
 configs['train_conf']["dtype"] = "fp32"
 assert ds_configs["train_micro_batch_size_per_gpu"] == 1
 # if use deepspeed, override ddp config
-configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] * configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
+configs['train_conf']['save_per_step'] = int(configs['train_conf']['save_per_step'] *
+configs['train_conf']['accum_grad'] / ds_configs["gradient_accumulation_steps"])
 configs['train_conf']['accum_grad'] = ds_configs["gradient_accumulation_steps"]
 configs['train_conf']['grad_clip'] = ds_configs["gradient_clipping"]
 configs['train_conf']['log_interval'] = ds_configs["steps_per_print"]
examples/libritts/cosyvoice/local/prepare_data.py CHANGED
@@ -7,6 +7,7 @@ from tqdm import tqdm

 logger = logging.getLogger()

+
 def main():
 wavs = list(glob.glob('{}/*/*/*wav'.format(args.src_dir)))

@@ -41,6 +42,7 @@ def main():
 f.write('{} {}\n'.format(k, ' '.join(v)))
 return

+
 if __name__ == "__main__":
 parser = argparse.ArgumentParser()
 parser.add_argument('--src_dir',
examples/libritts/cosyvoice/run.sh CHANGED
@@ -83,7 +83,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
 cat data/{train-clean-100,train-clean-360,train-other-500}/parquet/data.list > data/train.data.list
 cat data/{dev-clean,dev-other}/parquet/data.list > data/dev.data.list
-for model in llm; do
+for model in llm flow; do
 torchrun --nnodes=1 --nproc_per_node=$num_gpus \
 --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
 cosyvoice/bin/train.py \
examples/magicdata-read/cosyvoice/local/prepare_data.py CHANGED
@@ -6,6 +6,7 @@ from tqdm import tqdm

 logger = logging.getLogger()

+
 def main():
 utt2wav, utt2text, utt2spk, spk2utt = {}, {}, {}, {}
 with open(os.path.join(args.src_dir, "TRANS.txt"), "r") as f:
@@ -40,6 +41,7 @@ def main():
 f.write('{} {}\n'.format(k, ' '.join(v)))
 return

+
 if __name__ == "__main__":
 parser = argparse.ArgumentParser()
 parser.add_argument('--src_dir',
examples/magicdata-read/cosyvoice/run.sh CHANGED
@@ -83,7 +83,7 @@ if [ ${stage} -le 5 ] && [ ${stop_stage} -ge 5 ]; then
 fi
 cp data/train/parquet/data.list data/train.data.list
 cp data/dev/parquet/data.list data/dev.data.list
-for model in llm; do
+for model in llm flow; do
 torchrun --nnodes=1 --nproc_per_node=$num_gpus \
 --rdzv_id=$job_id --rdzv_backend="c10d" --rdzv_endpoint="localhost:0" \
 cosyvoice/bin/train.py \
runtime/python/fastapi/client.py CHANGED
@@ -38,7 +38,7 @@ def main():
 payload = {
 'tts_text': args.tts_text,
 }
-files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav,'rb'), 'application/octet-stream'))]
+files = [('prompt_wav', ('prompt_wav', open(args.prompt_wav, 'rb'), 'application/octet-stream'))]
 response = requests.request("GET", url, data=payload, files=files, stream=True)
 else:
 payload = {
@@ -55,6 +55,7 @@ def main():
 torchaudio.save(args.tts_wav, tts_speech, target_sr)
 logging.info('get response')

+
 if __name__ == "__main__":
 parser = argparse.ArgumentParser()
 parser.add_argument('--host',
@@ -81,7 +82,8 @@ if __name__ == "__main__":
 default='../../../zero_shot_prompt.wav')
 parser.add_argument('--instruct_text',
 type=str,
-default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
+default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
+Fights with fervor for justice, but struggles with impulsiveness.')
 parser.add_argument('--tts_wav',
 type=str,
 default='demo.wav')
runtime/python/fastapi/server.py CHANGED
@@ -13,9 +13,6 @@
 # limitations under the License.
 import os
 import sys
-ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
-sys.path.append('{}/../../..'.format(ROOT_DIR))
-sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 import argparse
 import logging
 logging.getLogger('matplotlib').setLevel(logging.WARNING)
@@ -24,6 +21,9 @@ from fastapi.responses import StreamingResponse
 from fastapi.middleware.cors import CORSMiddleware
 import uvicorn
 import numpy as np
+ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
+sys.path.append('{}/../../..'.format(ROOT_DIR))
+sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
 from cosyvoice.cli.cosyvoice import CosyVoice
 from cosyvoice.utils.file_utils import load_wav

@@ -36,34 +36,40 @@ app.add_middleware(
 allow_methods=["*"],
 allow_headers=["*"])

+
 def generate_data(model_output):
 for i in model_output:
 tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
 yield tts_audio

+
 @app.get("/inference_sft")
 async def inference_sft(tts_text: str = Form(), spk_id: str = Form()):
 model_output = cosyvoice.inference_sft(tts_text, spk_id)
 return StreamingResponse(generate_data(model_output))

+
 @app.get("/inference_zero_shot")
 async def inference_zero_shot(tts_text: str = Form(), prompt_text: str = Form(), prompt_wav: UploadFile = File()):
 prompt_speech_16k = load_wav(prompt_wav.file, 16000)
 model_output = cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k)
 return StreamingResponse(generate_data(model_output))

+
 @app.get("/inference_cross_lingual")
 async def inference_cross_lingual(tts_text: str = Form(), prompt_wav: UploadFile = File()):
 prompt_speech_16k = load_wav(prompt_wav.file, 16000)
 model_output = cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k)
 return StreamingResponse(generate_data(model_output))

+
 @app.get("/inference_instruct")
 async def inference_instruct(tts_text: str = Form(), spk_id: str = Form(), instruct_text: str = Form()):
 model_output = cosyvoice.inference_instruct(tts_text, spk_id, instruct_text)
 return StreamingResponse(generate_data(model_output))

-if __name__=='__main__':
+
+if __name__ == '__main__':
 parser = argparse.ArgumentParser()
 parser.add_argument('--port',
 type=int,
@@ -74,4 +80,4 @@ if __name__=='__main__':
 help='local path or modelscope repo id')
 args = parser.parse_args()
 cosyvoice = CosyVoice(args.model_dir)
-uvicorn.run(app, host="127.0.0.1", port=args.port)
+uvicorn.run(app, host="127.0.0.1", port=args.port)
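A hedged client-side sketch for the /inference_sft route defined above, mirroring the GET-with-form-data, streamed raw int16 pattern used by the repo's own clients in this commit; the host, port, text, speaker id and the 22050 Hz save rate are assumptions, not values from this diff:

import numpy as np
import requests
import torch
import torchaudio

url = 'http://127.0.0.1:50000/inference_sft'            # assumed host/port
payload = {'tts_text': 'Hello, CosyVoice.', 'spk_id': 'some_spk_id'}  # illustrative values
response = requests.request("GET", url, data=payload, stream=True)
tts_audio = b''
for chunk in response.iter_content(chunk_size=16000):
    tts_audio += chunk
# The server streams int16 PCM scaled by 2**15; convert back to float for saving.
tts_speech = torch.from_numpy(np.frombuffer(tts_audio, dtype=np.int16).copy()).unsqueeze(dim=0).float() / (2 ** 15)
torchaudio.save('demo.wav', tts_speech, 22050)            # assumed output sample rate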
runtime/python/grpc/client.py CHANGED
@@ -96,7 +96,8 @@ if __name__ == "__main__":
96
  default='../../../zero_shot_prompt.wav')
97
  parser.add_argument('--instruct_text',
98
  type=str,
99
- default='Theo \'Crimson\', is a fiery, passionate rebel leader. Fights with fervor for justice, but struggles with impulsiveness.')
 
100
  parser.add_argument('--tts_wav',
101
  type=str,
102
  default='demo.wav')
 
96
  default='../../../zero_shot_prompt.wav')
97
  parser.add_argument('--instruct_text',
98
  type=str,
99
+ default='Theo \'Crimson\', is a fiery, passionate rebel leader. \
100
+ Fights with fervor for justice, but struggles with impulsiveness.')
101
  parser.add_argument('--tts_wav',
102
  type=str,
103
  default='demo.wav')
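One side effect of wrapping the --instruct_text default with a backslash continuation inside the quoted literal: the newline is dropped, but the second line's leading indentation becomes part of the string, so the default now contains a long run of spaces after 'rebel leader.'. If a single space is wanted instead, adjacent string literals concatenate with no extra whitespace; a small self-contained sketch of that alternative (illustrative only):

import argparse

parser = argparse.ArgumentParser()
parser.add_argument('--instruct_text',
                    type=str,
                    default='Theo \'Crimson\', is a fiery, passionate rebel leader. '
                            'Fights with fervor for justice, but struggles with impulsiveness.')
print(parser.parse_args([]).instruct_text)  # exactly one space before 'Fights'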
runtime/python/grpc/server.py CHANGED
@@ -13,9 +13,6 @@
13
  # limitations under the License.
14
  import os
15
  import sys
16
- ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17
- sys.path.append('{}/../../..'.format(ROOT_DIR))
18
- sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
19
  from concurrent import futures
20
  import argparse
21
  import cosyvoice_pb2
@@ -25,11 +22,15 @@ logging.getLogger('matplotlib').setLevel(logging.WARNING)
25
  import grpc
26
  import torch
27
  import numpy as np
 
 
 
28
  from cosyvoice.cli.cosyvoice import CosyVoice
29
 
30
  logging.basicConfig(level=logging.DEBUG,
31
  format='%(asctime)s %(levelname)s %(message)s')
32
 
 
33
  class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
34
  def __init__(self, args):
35
  self.cosyvoice = CosyVoice(args.model_dir)
@@ -43,7 +44,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
43
  logging.info('get zero_shot inference request')
44
  prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
45
  prompt_speech_16k = prompt_speech_16k.float() / (2**15)
46
- model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text, request.zero_shot_request.prompt_text, prompt_speech_16k)
 
 
47
  elif request.HasField('cross_lingual_request'):
48
  logging.info('get cross_lingual inference request')
49
  prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
@@ -51,7 +54,9 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
51
  model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
52
  else:
53
  logging.info('get instruct inference request')
54
- model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text, request.instruct_request.spk_id, request.instruct_request.instruct_text)
 
 
55
 
56
  logging.info('send inference response')
57
  for i in model_output:
@@ -59,6 +64,7 @@ class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
59
  response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
60
  yield response
61
 
 
62
  def main():
63
  grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
64
  cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
 
13
  # limitations under the License.
14
  import os
15
  import sys
 
 
 
16
  from concurrent import futures
17
  import argparse
18
  import cosyvoice_pb2
 
22
  import grpc
23
  import torch
24
  import numpy as np
25
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
26
+ sys.path.append('{}/../../..'.format(ROOT_DIR))
27
+ sys.path.append('{}/../../../third_party/Matcha-TTS'.format(ROOT_DIR))
28
  from cosyvoice.cli.cosyvoice import CosyVoice
29
 
30
  logging.basicConfig(level=logging.DEBUG,
31
  format='%(asctime)s %(levelname)s %(message)s')
32
 
33
+
34
  class CosyVoiceServiceImpl(cosyvoice_pb2_grpc.CosyVoiceServicer):
35
  def __init__(self, args):
36
  self.cosyvoice = CosyVoice(args.model_dir)
 
44
  logging.info('get zero_shot inference request')
45
  prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.zero_shot_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
46
  prompt_speech_16k = prompt_speech_16k.float() / (2**15)
47
+ model_output = self.cosyvoice.inference_zero_shot(request.zero_shot_request.tts_text,
48
+ request.zero_shot_request.prompt_text,
49
+ prompt_speech_16k)
50
  elif request.HasField('cross_lingual_request'):
51
  logging.info('get cross_lingual inference request')
52
  prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(request.cross_lingual_request.prompt_audio, dtype=np.int16))).unsqueeze(dim=0)
 
54
  model_output = self.cosyvoice.inference_cross_lingual(request.cross_lingual_request.tts_text, prompt_speech_16k)
55
  else:
56
  logging.info('get instruct inference request')
57
+ model_output = self.cosyvoice.inference_instruct(request.instruct_request.tts_text,
58
+ request.instruct_request.spk_id,
59
+ request.instruct_request.instruct_text)
60
 
61
  logging.info('send inference response')
62
  for i in model_output:
 
64
  response.tts_audio = (i['tts_speech'].numpy() * (2 ** 15)).astype(np.int16).tobytes()
65
  yield response
66
 
67
+
68
  def main():
69
  grpcServer = grpc.server(futures.ThreadPoolExecutor(max_workers=args.max_conc), maximum_concurrent_rpcs=args.max_conc)
70
  cosyvoice_pb2_grpc.add_CosyVoiceServicer_to_server(CosyVoiceServiceImpl(args), grpcServer)
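Both the prompt decoding and the response encoding above use the same int16 <-> float convention, scaling by 2**15. A self-contained round-trip sketch of that conversion, using a synthetic sine wave in place of real request data:

import numpy as np
import torch

# synthetic stand-in for request.zero_shot_request.prompt_audio (1 s of 440 Hz at 16 kHz)
pcm_bytes = (np.sin(2 * np.pi * 440 * np.arange(16000) / 16000) * (2 ** 15 - 1)).astype(np.int16).tobytes()

# bytes -> float tensor in [-1, 1), as done for the prompt audio above
prompt_speech_16k = torch.from_numpy(np.array(np.frombuffer(pcm_bytes, dtype=np.int16))).unsqueeze(dim=0)
prompt_speech_16k = prompt_speech_16k.float() / (2 ** 15)

# float tensor -> bytes, as done for response.tts_audio
tts_audio = (prompt_speech_16k.numpy() * (2 ** 15)).astype(np.int16).tobytes()
assert tts_audio == pcm_bytes  # the scaling is exact for int16 values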
tools/extract_embedding.py CHANGED
@@ -59,6 +59,7 @@ def main(args):
59
  torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
  torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61
 
 
62
  if __name__ == "__main__":
63
  parser = argparse.ArgumentParser()
64
  parser.add_argument('--dir',
 
59
  torch.save(utt2embedding, '{}/utt2embedding.pt'.format(args.dir))
60
  torch.save(spk2embedding, '{}/spk2embedding.pt'.format(args.dir))
61
 
62
+
63
  if __name__ == "__main__":
64
  parser = argparse.ArgumentParser()
65
  parser.add_argument('--dir',
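The script above finishes by writing two dict-style checkpoints with torch.save; from their names they map utterance and speaker ids to embeddings. A quick sanity check after a run (the 'data/train' path is only an assumed value of --dir):

import torch

utt2embedding = torch.load('data/train/utt2embedding.pt')
spk2embedding = torch.load('data/train/spk2embedding.pt')
print(len(utt2embedding), 'utterances,', len(spk2embedding), 'speakers')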
tools/make_parquet_list.py CHANGED
@@ -53,6 +53,7 @@ def job(utt_list, parquet_file, utt2parquet_file, spk2parquet_file):
53
  json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54
  logging.info('spend time {}'.format(time.time() - start_time))
55
 
 
56
  if __name__ == "__main__":
57
  parser = argparse.ArgumentParser()
58
  parser.add_argument('--num_utts_per_parquet',
 
53
  json.dump({k: parquet_file for k in list(set(spk_list))}, f, ensure_ascii=False, indent=2)
54
  logging.info('spend time {}'.format(time.time() - start_time))
55
 
56
+
57
  if __name__ == "__main__":
58
  parser = argparse.ArgumentParser()
59
  parser.add_argument('--num_utts_per_parquet',
webui.py CHANGED
@@ -13,9 +13,6 @@
13
  # limitations under the License.
14
  import os
15
  import sys
16
- ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
17
- sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
18
-
19
  import argparse
20
  import gradio as gr
21
  import numpy as np
@@ -23,9 +20,19 @@ import torch
23
  import torchaudio
24
  import random
25
  import librosa
26
-
 
27
  from cosyvoice.cli.cosyvoice import CosyVoice
28
- from cosyvoice.utils.file_utils import load_wav, speed_change, logging
 
 
 
 
 
 
 
 
 
29
 
30
  def generate_seed():
31
  seed = random.randint(1, 100000000)
@@ -34,13 +41,14 @@ def generate_seed():
34
  "value": seed
35
  }
36
 
 
37
  def set_all_random_seed(seed):
38
  random.seed(seed)
39
  np.random.seed(seed)
40
  torch.manual_seed(seed)
41
  torch.cuda.manual_seed_all(seed)
42
 
43
- max_val = 0.8
44
  def postprocess(speech, top_db=60, hop_length=220, win_length=440):
45
  speech, _ = librosa.effects.trim(
46
  speech, top_db=top_db,
@@ -52,16 +60,13 @@ def postprocess(speech, top_db=60, hop_length=220, win_length=440):
52
  speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
53
  return speech
54
 
55
- inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
56
- instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
57
- '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
58
- '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
59
- '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
60
- stream_mode_list = [('否', False), ('是', True)]
61
  def change_instruction(mode_checkbox_group):
62
  return instruct_dict[mode_checkbox_group]
63
 
64
- def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor):
 
 
65
  if prompt_wav_upload is not None:
66
  prompt_wav = prompt_wav_upload
67
  elif prompt_wav_record is not None:
@@ -72,31 +77,31 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
72
  if mode_checkbox_group in ['自然语言控制']:
73
  if cosyvoice.frontend.instruct is False:
74
  gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
75
- return (target_sr, default_data)
76
  if instruct_text == '':
77
  gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
78
- return (target_sr, default_data)
79
  if prompt_wav is not None or prompt_text != '':
80
  gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
81
  # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
82
  if mode_checkbox_group in ['跨语种复刻']:
83
  if cosyvoice.frontend.instruct is True:
84
  gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
85
- return (target_sr, default_data)
86
  if instruct_text != '':
87
  gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
88
  if prompt_wav is None:
89
  gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
90
- return (target_sr, default_data)
91
  gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
92
  # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
93
  if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
94
  if prompt_wav is None:
95
  gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
96
- return (target_sr, default_data)
97
  if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
98
  gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
99
- return (target_sr, default_data)
100
  # sft mode only use sft_dropdown
101
  if mode_checkbox_group in ['预训练音色']:
102
  if instruct_text != '' or prompt_wav is not None or prompt_text != '':
@@ -105,7 +110,7 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
105
  if mode_checkbox_group in ['3s极速复刻']:
106
  if prompt_text == '':
107
  gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
108
- return (target_sr, default_data)
109
  if instruct_text != '':
110
  gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
111
 
@@ -113,28 +118,32 @@ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, pro
113
  logging.info('get sft inference request')
114
  set_all_random_seed(seed)
115
  for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
116
- yield (target_sr, i['tts_speech'].numpy().flatten())
117
  elif mode_checkbox_group == '3s极速复刻':
118
  logging.info('get zero_shot inference request')
119
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
120
  set_all_random_seed(seed)
121
  for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
122
- yield (target_sr, i['tts_speech'].numpy().flatten())
123
  elif mode_checkbox_group == '跨语种复刻':
124
  logging.info('get cross_lingual inference request')
125
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
126
  set_all_random_seed(seed)
127
  for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
128
- yield (target_sr, i['tts_speech'].numpy().flatten())
129
  else:
130
  logging.info('get instruct inference request')
131
  set_all_random_seed(seed)
132
  for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
133
- yield (target_sr, i['tts_speech'].numpy().flatten())
 
134
 
135
  def main():
136
  with gr.Blocks() as demo:
137
- gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
 
 
 
138
  gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
139
 
140
  tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
@@ -160,12 +169,14 @@ def main():
160
 
161
  seed_button.click(generate_seed, inputs=[], outputs=seed)
162
  generate_button.click(generate_audio,
163
- inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text, seed, stream, speed_factor],
 
164
  outputs=[audio_output])
165
  mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
166
  demo.queue(max_size=4, default_concurrency_limit=2)
167
  demo.launch(server_name='0.0.0.0', server_port=args.port)
168
 
 
169
  if __name__ == '__main__':
170
  parser = argparse.ArgumentParser()
171
  parser.add_argument('--port',
 
13
  # limitations under the License.
14
  import os
15
  import sys
 
 
 
16
  import argparse
17
  import gradio as gr
18
  import numpy as np
 
20
  import torchaudio
21
  import random
22
  import librosa
23
+ ROOT_DIR = os.path.dirname(os.path.abspath(__file__))
24
+ sys.path.append('{}/third_party/Matcha-TTS'.format(ROOT_DIR))
25
  from cosyvoice.cli.cosyvoice import CosyVoice
26
+ from cosyvoice.utils.file_utils import load_wav, logging
27
+
28
+ inference_mode_list = ['预训练音色', '3s极速复刻', '跨语种复刻', '自然语言控制']
29
+ instruct_dict = {'预训练音色': '1. 选择预训练音色\n2. 点击生成音频按钮',
30
+ '3s极速复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 输入prompt文本\n3. 点击生成音频按钮',
31
+ '跨语种复刻': '1. 选择prompt音频文件,或录入prompt音频,注意不超过30s,若同时提供,优先选择prompt音频文件\n2. 点击生成音频按钮',
32
+ '自然语言控制': '1. 选择预训练音色\n2. 输入instruct文本\n3. 点击生成音频按钮'}
33
+ stream_mode_list = [('否', False), ('是', True)]
34
+ max_val = 0.8
35
+
36
 
37
  def generate_seed():
38
  seed = random.randint(1, 100000000)
 
41
  "value": seed
42
  }
43
 
44
+
45
  def set_all_random_seed(seed):
46
  random.seed(seed)
47
  np.random.seed(seed)
48
  torch.manual_seed(seed)
49
  torch.cuda.manual_seed_all(seed)
50
 
51
+
52
  def postprocess(speech, top_db=60, hop_length=220, win_length=440):
53
  speech, _ = librosa.effects.trim(
54
  speech, top_db=top_db,
 
60
  speech = torch.concat([speech, torch.zeros(1, int(target_sr * 0.2))], dim=1)
61
  return speech
62
 
63
+
 
 
 
 
 
64
  def change_instruction(mode_checkbox_group):
65
  return instruct_dict[mode_checkbox_group]
66
 
67
+
68
+ def generate_audio(tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
69
+ seed, stream, speed_factor):
70
  if prompt_wav_upload is not None:
71
  prompt_wav = prompt_wav_upload
72
  elif prompt_wav_record is not None:
 
77
  if mode_checkbox_group in ['自然语言控制']:
78
  if cosyvoice.frontend.instruct is False:
79
  gr.Warning('您正在使用自然语言控制模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M-Instruct模型'.format(args.model_dir))
80
+ yield (target_sr, default_data)
81
  if instruct_text == '':
82
  gr.Warning('您正在使用自然语言控制模式, 请输入instruct文本')
83
+ yield (target_sr, default_data)
84
  if prompt_wav is not None or prompt_text != '':
85
  gr.Info('您正在使用自然语言控制模式, prompt音频/prompt文本会被忽略')
86
  # if cross_lingual mode, please make sure that model is iic/CosyVoice-300M and tts_text prompt_text are different language
87
  if mode_checkbox_group in ['跨语种复刻']:
88
  if cosyvoice.frontend.instruct is True:
89
  gr.Warning('您正在使用跨语种复刻模式, {}模型不支持此模式, 请使用iic/CosyVoice-300M模型'.format(args.model_dir))
90
+ yield (target_sr, default_data)
91
  if instruct_text != '':
92
  gr.Info('您正在使用跨语种复刻模式, instruct文本会被忽略')
93
  if prompt_wav is None:
94
  gr.Warning('您正在使用跨语种复刻模式, 请提供prompt音频')
95
+ yield (target_sr, default_data)
96
  gr.Info('您正在使用跨语种复刻模式, 请确保合成文本和prompt文本为不同语言')
97
  # if in zero_shot cross_lingual, please make sure that prompt_text and prompt_wav meets requirements
98
  if mode_checkbox_group in ['3s极速复刻', '跨语种复刻']:
99
  if prompt_wav is None:
100
  gr.Warning('prompt音频为空,您是否忘记输入prompt音频?')
101
+ yield (target_sr, default_data)
102
  if torchaudio.info(prompt_wav).sample_rate < prompt_sr:
103
  gr.Warning('prompt音频采样率{}低于{}'.format(torchaudio.info(prompt_wav).sample_rate, prompt_sr))
104
+ yield (target_sr, default_data)
105
  # sft mode only use sft_dropdown
106
  if mode_checkbox_group in ['预训练音色']:
107
  if instruct_text != '' or prompt_wav is not None or prompt_text != '':
 
110
  if mode_checkbox_group in ['3s极速复刻']:
111
  if prompt_text == '':
112
  gr.Warning('prompt文本为空,您是否忘记输入prompt文本?')
113
+ yield (target_sr, default_data)
114
  if instruct_text != '':
115
  gr.Info('您正在使用3s极速复刻模式,预训练音色/instruct文本会被忽略!')
116
 
 
118
  logging.info('get sft inference request')
119
  set_all_random_seed(seed)
120
  for i in cosyvoice.inference_sft(tts_text, sft_dropdown, stream=stream):
121
+ yield (target_sr, i['tts_speech'].numpy().flatten())
122
  elif mode_checkbox_group == '3s极速复刻':
123
  logging.info('get zero_shot inference request')
124
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
125
  set_all_random_seed(seed)
126
  for i in cosyvoice.inference_zero_shot(tts_text, prompt_text, prompt_speech_16k, stream=stream):
127
+ yield (target_sr, i['tts_speech'].numpy().flatten())
128
  elif mode_checkbox_group == '跨语种复刻':
129
  logging.info('get cross_lingual inference request')
130
  prompt_speech_16k = postprocess(load_wav(prompt_wav, prompt_sr))
131
  set_all_random_seed(seed)
132
  for i in cosyvoice.inference_cross_lingual(tts_text, prompt_speech_16k, stream=stream):
133
+ yield (target_sr, i['tts_speech'].numpy().flatten())
134
  else:
135
  logging.info('get instruct inference request')
136
  set_all_random_seed(seed)
137
  for i in cosyvoice.inference_instruct(tts_text, sft_dropdown, instruct_text, stream=stream):
138
+ yield (target_sr, i['tts_speech'].numpy().flatten())
139
+
140
 
141
  def main():
142
  with gr.Blocks() as demo:
143
+ gr.Markdown("### 代码库 [CosyVoice](https://github.com/FunAudioLLM/CosyVoice) \
144
+ 预训练模型 [CosyVoice-300M](https://www.modelscope.cn/models/iic/CosyVoice-300M) \
145
+ [CosyVoice-300M-Instruct](https://www.modelscope.cn/models/iic/CosyVoice-300M-Instruct) \
146
+ [CosyVoice-300M-SFT](https://www.modelscope.cn/models/iic/CosyVoice-300M-SFT)")
147
  gr.Markdown("#### 请输入需要合成的文本,选择推理模式,并按照提示步骤进行操作")
148
 
149
  tts_text = gr.Textbox(label="输入合成文本", lines=1, value="我是通义实验室语音团队全新推出的生成式语音大模型,提供舒适自然的语音合成能力。")
 
169
 
170
  seed_button.click(generate_seed, inputs=[], outputs=seed)
171
  generate_button.click(generate_audio,
172
+ inputs=[tts_text, mode_checkbox_group, sft_dropdown, prompt_text, prompt_wav_upload, prompt_wav_record, instruct_text,
173
+ seed, stream, speed_factor],
174
  outputs=[audio_output])
175
  mode_checkbox_group.change(fn=change_instruction, inputs=[mode_checkbox_group], outputs=[instruction_text])
176
  demo.queue(max_size=4, default_concurrency_limit=2)
177
  demo.launch(server_name='0.0.0.0', server_port=args.port)
178
 
179
+
180
  if __name__ == '__main__':
181
  parser = argparse.ArgumentParser()
182
  parser.add_argument('--port',