CosyVoice committed
Commit 72b89a5 · 1 Parent(s): 49015f6

update vc/tts code

cosyvoice/cli/model.py CHANGED
@@ -35,7 +35,7 @@ class CosyVoiceModel:
         self.token_max_hop_len = 200
         self.token_overlap_len = 20
         # mel fade in out
-        self.mel_overlap_len = 34
+        self.mel_overlap_len = int(self.token_overlap_len / self.flow.input_frame_rate * 22050 / 256)
         self.mel_window = np.hamming(2 * self.mel_overlap_len)
         # hift cache
         self.mel_cache_len = 20
@@ -54,9 +54,10 @@ class CosyVoiceModel:
         self.hift_cache_dict = {}
 
     def load(self, llm_model, flow_model, hift_model):
-        self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
-        self.llm.to(self.device).eval()
-        self.llm.half()
+        if self.llm is not None:
+            self.llm.load_state_dict(torch.load(llm_model, map_location=self.device))
+            self.llm.to(self.device).eval()
+            self.llm.half()
         self.flow.load_state_dict(torch.load(flow_model, map_location=self.device))
         self.flow.to(self.device).eval()
         self.hift.load_state_dict(torch.load(hift_model, map_location=self.device))
@@ -131,11 +132,11 @@ class CosyVoiceModel:
         tts_speech = fade_in_out(tts_speech, self.hift_cache_dict[uuid]['speech'], self.speech_window)
         return tts_speech
 
-    def inference(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
-                  prompt_text=torch.zeros(1, 0, dtype=torch.int32),
-                  llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
-                  prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
+    def tts(self, text, flow_embedding, llm_embedding=torch.zeros(0, 192),
+            prompt_text=torch.zeros(1, 0, dtype=torch.int32),
+            llm_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            flow_prompt_speech_token=torch.zeros(1, 0, dtype=torch.int32),
+            prompt_speech_feat=torch.zeros(1, 0, 80), stream=False, speed=1.0, **kwargs):
         # this_uuid is used to track variables related to this inference thread
         this_uuid = str(uuid.uuid1())
         with self.lock:
@@ -148,7 +149,8 @@ class CosyVoiceModel:
             while True:
                 time.sleep(0.1)
                 if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
-                    this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len], dim=1)
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
                     this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                                      prompt_token=flow_prompt_speech_token,
                                                      prompt_feat=prompt_speech_feat,
@@ -164,7 +166,7 @@ class CosyVoiceModel:
                     break
             p.join()
             # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
@@ -175,7 +177,58 @@ class CosyVoiceModel:
         else:
             # deal with all tokens
             p.join()
-            this_tts_speech_token = torch.concat(self.tts_speech_token_dict[this_uuid], dim=1)
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True,
+                                             speed=speed)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        with self.lock:
+            self.tts_speech_token_dict.pop(this_uuid)
+            self.llm_end_dict.pop(this_uuid)
+            self.mel_overlap_dict.pop(this_uuid)
+            self.hift_cache_dict.pop(this_uuid)
+
+    def vc(self, source_speech_token, flow_prompt_speech_token, prompt_speech_feat, flow_embedding, stream=False, speed=1.0, **kwargs):
+        # this_uuid is used to track variables related to this inference thread
+        this_uuid = str(uuid.uuid1())
+        with self.lock:
+            self.tts_speech_token_dict[this_uuid], self.llm_end_dict[this_uuid] = source_speech_token.flatten().tolist(), True
+            self.mel_overlap_dict[this_uuid], self.hift_cache_dict[this_uuid] = None, None
+        if stream is True:
+            token_hop_len = self.token_min_hop_len
+            while True:
+                if len(self.tts_speech_token_dict[this_uuid]) >= token_hop_len + self.token_overlap_len:
+                    this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid][:token_hop_len + self.token_overlap_len]) \
+                        .unsqueeze(dim=0)
+                    this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                                     prompt_token=flow_prompt_speech_token,
+                                                     prompt_feat=prompt_speech_feat,
+                                                     embedding=flow_embedding,
+                                                     uuid=this_uuid,
+                                                     finalize=False)
+                    yield {'tts_speech': this_tts_speech.cpu()}
+                    with self.lock:
+                        self.tts_speech_token_dict[this_uuid] = self.tts_speech_token_dict[this_uuid][token_hop_len:]
+                    # increase token_hop_len for better speech quality
+                    token_hop_len = min(self.token_max_hop_len, int(token_hop_len * self.stream_scale_factor))
+                if self.llm_end_dict[this_uuid] is True and len(self.tts_speech_token_dict[this_uuid]) < token_hop_len + self.token_overlap_len:
+                    break
+            # deal with remain tokens, make sure inference remain token len equals token_hop_len when cache_speech is not None
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid], dim=1).unsqueeze(dim=0)
+            this_tts_speech = self.token2wav(token=this_tts_speech_token,
+                                             prompt_token=flow_prompt_speech_token,
+                                             prompt_feat=prompt_speech_feat,
+                                             embedding=flow_embedding,
+                                             uuid=this_uuid,
+                                             finalize=True)
+            yield {'tts_speech': this_tts_speech.cpu()}
+        else:
+            # deal with all tokens
+            this_tts_speech_token = torch.tensor(self.tts_speech_token_dict[this_uuid]).unsqueeze(dim=0)
             this_tts_speech = self.token2wav(token=this_tts_speech_token,
                                              prompt_token=flow_prompt_speech_token,
                                              prompt_feat=prompt_speech_feat,
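
In model.py, the first hunk replaces the hard-coded mel overlap of 34 frames with a value derived from the token overlap, the flow model's input frame rate, the 22050 Hz output sample rate, and the 256-sample hop length, so the fade window tracks whichever flow model is loaded. A minimal sketch of that arithmetic (an illustration, not code from this commit; the 25 Hz case is an assumed example):

def mel_overlap_len(token_overlap_len, input_frame_rate, sample_rate=22050, hop_length=256):
    # token count -> seconds -> output samples -> mel frames, as in the patched __init__
    return int(token_overlap_len / input_frame_rate * sample_rate / hop_length)

print(mel_overlap_len(20, 50))  # 34, reproduces the previously hard-coded value at 50 Hz
print(mel_overlap_len(20, 25))  # 68, what a hypothetical 25 Hz flow model would need
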
cosyvoice/flow/flow.py CHANGED
@@ -125,7 +125,7 @@ class MaskedDiffWithXvec(torch.nn.Module):
         h, h_lengths = self.encoder(token, token_len)
         h = self.encoder_proj(h)
         mel_len1, mel_len2 = prompt_feat.shape[1], int(token_len2 / self.input_frame_rate * 22050 / 256)
-        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2)
+        h, h_lengths = self.length_regulator.inference(h[:, :token_len1], h[:, token_len1:], mel_len1, mel_len2, self.input_frame_rate)
 
         # get conditions
         conds = torch.zeros([1, mel_len1 + mel_len2, self.output_size], device=token.device)
cosyvoice/flow/length_regulator.py CHANGED
@@ -49,13 +49,14 @@ class InterpolateRegulator(nn.Module):
         olens = ylens
         return out * mask, olens
 
-    def inference(self, x1, x2, mel_len1, mel_len2):
+    def inference(self, x1, x2, mel_len1, mel_len2, input_frame_rate=50):
         # in inference mode, interploate prompt token and token(head/mid/tail) seprately, so we can get a clear separation point of mel
         # x in (B, T, D)
         if x2.shape[1] > 40:
-            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=34, mode='linear')
-            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - 34 * 2, mode='linear')
-            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=34, mode='linear')
+            x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
+            x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - int(20 / input_frame_rate * 22050 / 256) * 2,
+                                   mode='linear')
+            x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=int(20 / input_frame_rate * 22050 / 256), mode='linear')
             x2 = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
         else:
             x2 = F.interpolate(x2.transpose(1, 2).contiguous(), size=mel_len2, mode='linear')
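
The same token-to-mel conversion appears here: the 20 edge tokens now map to int(20 / input_frame_rate * 22050 / 256) mel frames instead of a fixed 34. A small self-contained sketch of the head/mid/tail interpolation above, using dummy tensors and an assumed 50 Hz token rate (the edge and mel_len2 names are local to this sketch):

import torch
import torch.nn.functional as F

input_frame_rate = 50                                          # assumed token rate
x2 = torch.randn(1, 100, 80)                                   # (B, T, D) dummy token embeddings, T > 40
mel_len2 = int(x2.shape[1] / input_frame_rate * 22050 / 256)   # target mel length, as computed in flow.py
edge = int(20 / input_frame_rate * 22050 / 256)                # mel frames covered by the 20 edge tokens

# interpolate head, middle and tail separately so the seams sit at fixed offsets
x2_head = F.interpolate(x2[:, :20].transpose(1, 2).contiguous(), size=edge, mode='linear')
x2_mid = F.interpolate(x2[:, 20:-20].transpose(1, 2).contiguous(), size=mel_len2 - edge * 2, mode='linear')
x2_tail = F.interpolate(x2[:, -20:].transpose(1, 2).contiguous(), size=edge, mode='linear')

out = torch.concat([x2_head, x2_mid, x2_tail], dim=2)
assert out.shape[-1] == mel_len2                               # the three segments add back up to mel_len2
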
cosyvoice/tokenizer/tokenizer.py CHANGED
@@ -1,9 +1,7 @@
 import base64
 import os
-import string
-from dataclasses import dataclass, field
-from functools import cached_property, lru_cache
-from typing import Dict, List, Optional, Tuple
+from functools import lru_cache
+from typing import Optional
 from whisper.tokenizer import Tokenizer
 
 import tiktoken
cosyvoice/utils/common.py CHANGED
@@ -145,6 +145,7 @@ def fade_in_out(fade_in_mel, fade_out_mel, window):
         fade_out_mel[..., -mel_overlap_len:] * window[mel_overlap_len:]
     return fade_in_mel.to(device)
 
+
 def set_all_random_seed(seed):
     random.seed(seed)
     np.random.seed(seed)