CosyVoice committed on
Commit a13411c
1 Parent(s): 2895d99

add stream code

cosyvoice/cli/cosyvoice.py CHANGED
@@ -12,11 +12,12 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import os
-import torch
+import time
 from hyperpyyaml import load_hyperpyyaml
 from modelscope import snapshot_download
 from cosyvoice.cli.frontend import CosyVoiceFrontEnd
 from cosyvoice.cli.model import CosyVoiceModel
+from cosyvoice.utils.file_utils import logging

 class CosyVoice:

@@ -44,40 +45,48 @@ class CosyVoice:
         spks = list(self.frontend.spk2info.keys())
         return spks

-    def inference_sft(self, tts_text, spk_id):
-        tts_speeches = []
+    def inference_sft(self, tts_text, spk_id, stream=False):
+        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_sft(i, spk_id)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k):
+    def inference_zero_shot(self, tts_text, prompt_text, prompt_speech_16k, stream=False):
+        start_time = time.time()
         prompt_text = self.frontend.text_normalize(prompt_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_zero_shot(i, prompt_text, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_cross_lingual(self, tts_text, prompt_speech_16k):
+    def inference_cross_lingual(self, tts_text, prompt_speech_16k, stream=False):
         if self.frontend.instruct is True:
             raise ValueError('{} do not support cross_lingual inference'.format(self.model_dir))
-        tts_speeches = []
+        start_time = time.time()
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_cross_lingual(i, prompt_speech_16k)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()

-    def inference_instruct(self, tts_text, spk_id, instruct_text):
+    def inference_instruct(self, tts_text, spk_id, instruct_text, stream=False):
         if self.frontend.instruct is False:
             raise ValueError('{} do not support instruct inference'.format(self.model_dir))
+        start_time = time.time()
         instruct_text = self.frontend.text_normalize(instruct_text, split=False)
-        tts_speeches = []
         for i in self.frontend.text_normalize(tts_text, split=True):
             model_input = self.frontend.frontend_instruct(i, spk_id, instruct_text)
-            model_output = self.model.inference(**model_input)
-            tts_speeches.append(model_output['tts_speech'])
-        return {'tts_speech': torch.concat(tts_speeches, dim=1)}
+            for model_output in self.model.inference(**model_input, stream=stream):
+                speech_len = model_output['tts_speech'].shape[1] / 22050
+                logging.info('yield speech len {}, rtf {}'.format(speech_len, (time.time() - start_time) / speech_len))
+                yield model_output
+                start_time = time.time()
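With this change the CLI entry points (`inference_sft`, `inference_zero_shot`, `inference_cross_lingual`, `inference_instruct`) become generators: each yielded dict carries a `tts_speech` tensor at 22050 Hz, and a log line reports the chunk duration and real-time factor. A minimal usage sketch of the new API, assuming a downloaded SFT model directory; the model path and output file names below are placeholders, not part of this commit:

```python
# Hypothetical caller of the new streaming API; the model dir and wav names
# are placeholders. Each yielded chunk can be written (or played) as soon as
# it arrives instead of waiting for the full utterance.
import torchaudio
from cosyvoice.cli.cosyvoice import CosyVoice

cosyvoice = CosyVoice('pretrained_models/CosyVoice-300M-SFT')   # placeholder model dir
spk_id = list(cosyvoice.frontend.spk2info.keys())[0]            # any available speaker id
for idx, chunk in enumerate(cosyvoice.inference_sft('Hello, streaming world.', spk_id, stream=True)):
    torchaudio.save('chunk_{}.wav'.format(idx), chunk['tts_speech'], 22050)
```

With `stream=False` the same loop still works; it simply receives one chunk per normalized text segment.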
cosyvoice/cli/model.py CHANGED
@@ -12,6 +12,8 @@
 # See the License for the specific language governing permissions and
 # limitations under the License.
 import torch
+import numpy as np
+

 class CosyVoiceModel:

@@ -23,6 +25,10 @@ class CosyVoiceModel:
         self.llm = llm
         self.flow = flow
         self.hift = hift
+        self.stream_win_len = 60
+        self.stream_hop_len = 50
+        self.overlap = 4395  # 10 token equals 4395 sample point
+        self.window = np.hamming(2 * self.overlap)

     def load(self, llm_model, flow_model, hift_model):
         self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
@@ -36,25 +42,79 @@ class CosyVoiceModel:
                   prompt_text=torch.zeros(1, 0, dtype=torch.int32), prompt_text_len=torch.zeros(1, dtype=torch.int32),
                   llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), llm_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
                   flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32), flow_prompt_speech_token_len=torch.zeros(1, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32)):
-        tts_speech_token = self.llm.inference(text=text.to(self.device),
-                                              text_len=text_len.to(self.device),
-                                              prompt_text=prompt_text.to(self.device),
-                                              prompt_text_len=prompt_text_len.to(self.device),
-                                              prompt_speech_token=llm_prompt_speech_token.to(self.device),
-                                              prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
-                                              embedding=llm_embedding.to(self.device),
-                                              beam_size=1,
-                                              sampling=25,
-                                              max_token_text_ratio=30,
-                                              min_token_text_ratio=3)
-        tts_mel = self.flow.inference(token=tts_speech_token,
-                                      token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
-                                      prompt_token=flow_prompt_speech_token.to(self.device),
-                                      prompt_token_len=flow_prompt_speech_token_len.to(self.device),
-                                      prompt_feat=prompt_speech_feat.to(self.device),
-                                      prompt_feat_len=prompt_speech_feat_len.to(self.device),
-                                      embedding=flow_embedding.to(self.device))
-        tts_speech = self.hift.inference(mel=tts_mel).cpu()
-        torch.cuda.empty_cache()
-        return {'tts_speech': tts_speech}
+                  prompt_speech_feat=torch.zeros(1, 0, 80), prompt_speech_feat_len=torch.zeros(1, dtype=torch.int32), stream=False):
+        if stream is True:
+            tts_speech_token, cache_speech = [], None
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=text_len.to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=prompt_text_len.to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                        embedding=llm_embedding.to(self.device),
+                                        beam_size=1,
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3,
+                                        stream=stream):
+                tts_speech_token.append(i)
+                if len(tts_speech_token) == self.stream_win_len:
+                    this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                    this_tts_mel = self.flow.inference(token=this_tts_speech_token,
+                                                       token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                                       prompt_token=flow_prompt_speech_token.to(self.device),
+                                                       prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                       prompt_feat=prompt_speech_feat.to(self.device),
+                                                       prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                       embedding=flow_embedding.to(self.device))
+                    this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
+                    # fade in/out if necessary
+                    if cache_speech is not None:
+                        this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                    yield {'tts_speech': this_tts_speech[:, :-self.overlap]}
+                    cache_speech = this_tts_speech[:, -self.overlap:]
+                    tts_speech_token = tts_speech_token[-(self.stream_win_len - self.stream_hop_len):]
+            # deal with remain tokens
+            if cache_speech is None or len(tts_speech_token) > self.stream_win_len - self.stream_hop_len:
+                this_tts_speech_token = torch.concat(tts_speech_token, dim=1)
+                this_tts_mel = self.flow.inference(token=this_tts_speech_token,
+                                                   token_len=torch.tensor([this_tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                                   prompt_token=flow_prompt_speech_token.to(self.device),
+                                                   prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                                   prompt_feat=prompt_speech_feat.to(self.device),
+                                                   prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                                   embedding=flow_embedding.to(self.device))
+                this_tts_speech = self.hift.inference(mel=this_tts_mel).cpu()
+                if cache_speech is not None:
+                    this_tts_speech[:, :self.overlap] = this_tts_speech[:, :self.overlap] * self.window[:self.overlap] + cache_speech * self.window[-self.overlap:]
+                yield {'tts_speech': this_tts_speech}
+            else:
+                assert len(tts_speech_token) == self.stream_win_len - self.stream_hop_len, 'tts_speech_token not equal to {}'.format(self.stream_win_len - self.stream_hop_len)
+                yield {'tts_speech': cache_speech}
+        else:
+            tts_speech_token = []
+            for i in self.llm.inference(text=text.to(self.device),
+                                        text_len=text_len.to(self.device),
+                                        prompt_text=prompt_text.to(self.device),
+                                        prompt_text_len=prompt_text_len.to(self.device),
+                                        prompt_speech_token=llm_prompt_speech_token.to(self.device),
+                                        prompt_speech_token_len=llm_prompt_speech_token_len.to(self.device),
+                                        embedding=llm_embedding.to(self.device),
+                                        beam_size=1,
+                                        sampling=25,
+                                        max_token_text_ratio=30,
+                                        min_token_text_ratio=3,
+                                        stream=stream):
+                tts_speech_token.append(i)
+            assert len(tts_speech_token) == 1, 'tts_speech_token len should be 1 when stream is {}'.format(stream)
+            tts_speech_token = torch.concat(tts_speech_token, dim=1)
+            tts_mel = self.flow.inference(token=tts_speech_token,
+                                          token_len=torch.tensor([tts_speech_token.size(1)], dtype=torch.int32).to(self.device),
+                                          prompt_token=flow_prompt_speech_token.to(self.device),
+                                          prompt_token_len=flow_prompt_speech_token_len.to(self.device),
+                                          prompt_feat=prompt_speech_feat.to(self.device),
+                                          prompt_feat_len=prompt_speech_feat_len.to(self.device),
+                                          embedding=flow_embedding.to(self.device))
+            tts_speech = self.hift.inference(mel=tts_mel).cpu()
+            torch.cuda.empty_cache()
+            yield {'tts_speech': tts_speech}
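In stream mode, `CosyVoiceModel.inference` vocodes a sliding window of speech tokens (window 60, hop 50) and hides the chunk seams by cross-fading the 4395-sample overlap with a Hamming window; with the stated 10 tokens = 4395 samples, each 50-token hop yields roughly 21975 new samples, about one second of audio at 22050 Hz. Below is a standalone numpy sketch of just that fade step; the array shapes and toy signals are assumptions, not part of the commit:

```python
# Standalone sketch of the Hamming-window cross-fade used above. The rising
# half of the window fades the head of the new chunk in while the falling
# half fades the cached tail of the previous chunk out; the overlap region is
# the sum of the two. The toy signals below are assumed values.
import numpy as np

overlap = 4395                      # 10 speech tokens correspond to 4395 samples
window = np.hamming(2 * overlap)    # window[:overlap] rises, window[-overlap:] falls

def crossfade(prev_tail, new_chunk):
    out = new_chunk.copy()
    out[:overlap] = out[:overlap] * window[:overlap] + prev_tail * window[-overlap:]
    return out[:-overlap], out[-overlap:]   # (samples to emit now, tail to cache)

chunk_a = np.random.randn(26370).astype(np.float32)   # ~60 tokens of audio
chunk_b = np.random.randn(26370).astype(np.float32)
emit_a, tail = chunk_a[:-overlap], chunk_a[-overlap:]
emit_b, tail = crossfade(tail, chunk_b)
print(emit_a.shape, emit_b.shape)   # (21975,) (21975,)
```

The real code works on batched `1 x N` tensors and keeps the tail on the final chunk instead of trimming it, but the blending arithmetic is the same.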
cosyvoice/llm/llm.py CHANGED
@@ -158,6 +158,7 @@ class TransformerLM(torch.nn.Module):
             sampling: int = 25,
             max_token_text_ratio: float = 20,
             min_token_text_ratio: float = 2,
+            stream: bool = False,
     ) -> torch.Tensor:
         device = text.device
         text = torch.concat([prompt_text, text], dim=1)
@@ -199,8 +200,13 @@
             top_ids = self.sampling_ids(logp.squeeze(dim=0), sampling, beam_size, ignore_eos=True if i < min_len else False).item()
             if top_ids == self.speech_token_size:
                 break
+            # in stream mode, yield token one by one
+            if stream is True:
+                yield torch.tensor([[top_ids]], dtype=torch.int64, device=device)
             out_tokens.append(top_ids)
             offset += lm_input.size(1)
             lm_input = self.speech_embedding.weight[top_ids].reshape(1, 1, -1)

-        return torch.tensor([out_tokens], dtype=torch.int64, device=device)
+        # in non-stream mode, yield all token
+        if stream is False:
+            yield torch.tensor([out_tokens], dtype=torch.int64, device=device)
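`TransformerLM.inference` is now a generator in both modes: with `stream=True` it yields a `1 x 1` token tensor per decoding step, and with `stream=False` it yields a single `1 x T` tensor at the end, which is why `CosyVoiceModel.inference` collects the yields into a list and asserts the list length is 1 in the non-stream branch. A toy stand-in illustrating the two contracts (the token ids are fabricated; running the real method needs a loaded checkpoint):

```python
# Toy stand-in for the generator contract of TransformerLM.inference; the
# token ids below are fake and exist only to show the shapes the caller sees.
import torch

def fake_lm_inference(stream=False):
    out_tokens = []
    for top_ids in [101, 57, 998, 12]:      # pretend sampled speech tokens
        if stream is True:
            yield torch.tensor([[top_ids]], dtype=torch.int64)   # one 1 x 1 tensor per step
        out_tokens.append(top_ids)
    if stream is False:
        yield torch.tensor([out_tokens], dtype=torch.int64)      # a single 1 x T tensor

chunks = list(fake_lm_inference(stream=True))        # 4 tensors, concatenated by the caller
assert torch.concat(chunks, dim=1).shape == (1, 4)

full = list(fake_lm_inference(stream=False))         # exactly one tensor
assert len(full) == 1 and full[0].shape == (1, 4)
```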
cosyvoice/utils/file_utils.py CHANGED
@@ -15,6 +15,10 @@

 import json
 import torchaudio
+import logging
+logging.getLogger('matplotlib').setLevel(logging.WARNING)
+logging.basicConfig(level=logging.DEBUG,
+                    format='%(asctime)s %(levelname)s %(message)s')


 def read_lists(list_file):
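With logging configured here (root logger at DEBUG, timestamped format), the `logging.info` calls added in cosyvoice.py print one line per yielded chunk with its duration and real-time factor: elapsed wall-clock time since the previous yield divided by the chunk's audio duration. A worked example with assumed numbers:

```python
# Worked example of the per-chunk RTF logged in cosyvoice.py; the elapsed time
# and sample count below are assumed values, not measurements.
samples = 21975                     # one 50-token hop of audio
speech_len = samples / 22050        # ~0.997 s of audio in the chunk
elapsed = 0.30                      # assumed wall-clock seconds since the last yield
rtf = elapsed / speech_len          # ~0.30; values < 1.0 keep real-time playback fed
print('yield speech len {}, rtf {}'.format(speech_len, rtf))
```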