kartynnik commited on
Commit
2d9add8
1 Parent(s): 365f5cf

Run on CPU

Browse files
Files changed (1) hide show
  1. app.py +18 -19
app.py CHANGED
@@ -53,7 +53,6 @@ def tts(text,
53
  random_seed):
54
 
55
  torch.manual_seed(random_seed)
56
- torch.cuda.manual_seed(random_seed)
57
  np.random.seed(random_seed)
58
 
59
  text_len = len(text)
@@ -63,12 +62,12 @@ def tts(text,
63
  else:
64
  text = text_to_sequence(str(text), ["english_cleaners2"])
65
 
66
- token = add_blank_token(text).unsqueeze(0).cuda()
67
- token_length = torch.LongTensor([token.size(-1)]).cuda()
68
 
69
  # Prompt load
70
  # sample_rate, audio = prompt
71
- # audio = torch.FloatTensor([audio]).cuda()
72
  # if audio.shape[0] != 1:
73
  # audio = audio[:1,:]
74
  # audio = audio / 32768
@@ -89,28 +88,28 @@ def tts(text,
89
  # If you have a memory issue during denosing the prompt, try to denoise the prompt with cpu before TTS
90
  # We will have a plan to replace a memory-efficient denoiser
91
  if denoise == 0:
92
- audio = torch.cat([audio.cuda(), audio.cuda()], dim=0)
93
  else:
94
  with torch.no_grad():
95
 
96
  if ori_prompt_len > 80000:
97
  denoised_audio = []
98
  for i in range((ori_prompt_len//80000)):
99
- denoised_audio.append(denoise(audio.squeeze(0).cuda()[i*80000:(i+1)*80000], denoiser, hps_denoiser))
100
 
101
- denoised_audio.append(denoise(audio.squeeze(0).cuda()[(i+1)*80000:], denoiser, hps_denoiser))
102
  denoised_audio = torch.cat(denoised_audio, dim=1)
103
  else:
104
- denoised_audio = denoise(audio.squeeze(0).cuda(), denoiser, hps_denoiser)
105
 
106
- audio = torch.cat([audio.cuda(), denoised_audio[:,:audio.shape[-1]]], dim=0)
107
 
108
  audio = audio[:,:ori_prompt_len] # 20231108 We found that large size of padding decreases a performance so we remove the paddings after denosing.
109
 
110
  if audio.shape[-1]<48000:
111
  audio = torch.cat([audio,audio,audio,audio,audio], dim=1)
112
 
113
- src_mel = mel_fn(audio.cuda())
114
 
115
  src_length = torch.LongTensor([src_mel.size(2)]).to(device)
116
  src_length2 = torch.cat([src_length,src_length], dim=0)
@@ -120,9 +119,9 @@ def tts(text,
120
  w2v_x, pitch = text2w2v.infer_noise_control(token, token_length, src_mel, src_length2,
121
  noise_scale=ttv_temperature, noise_scale_w=duratuion_temperature,
122
  length_scale=duratuion_length, denoise_ratio=denoise_ratio)
123
- src_length = torch.LongTensor([w2v_x.size(2)]).cuda()
124
 
125
- pitch[pitch<torch.log(torch.tensor([55]).cuda())] = 0
126
 
127
  ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
128
  converted_audio = \
@@ -165,7 +164,7 @@ def main():
165
  a = parser.parse_args()
166
 
167
  global device, hps, hps_t2w2v,h_sr,h_sr48, hps_denoiser
168
- device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
169
 
170
  hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
171
  hps_t2w2v = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_text2w2v)[0], 'config.json'))
@@ -184,27 +183,27 @@ def main():
184
  f_max=hps.data.mel_fmax,
185
  n_mels=hps.data.n_mel_channels,
186
  window_fn=torch.hann_window
187
- ).cuda()
188
 
189
  net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
190
  hps.train.segment_size // hps.data.hop_length,
191
- **hps.model).cuda()
192
  net_g.load_state_dict(torch.load(a.ckpt))
193
  _ = net_g.eval()
194
 
195
  text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
196
  hps.train.segment_size // hps.data.hop_length,
197
- **hps_t2w2v.model).cuda()
198
  text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
199
  text2w2v.eval()
200
 
201
  speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
202
  h_sr48.train.segment_size // h_sr48.data.hop_length,
203
- **h_sr48.model).cuda()
204
  utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
205
  speechsr.eval()
206
 
207
- denoiser = MPNet(hps_denoiser).cuda()
208
  state_dict = load_checkpoint(a.denoiser_ckpt, device)
209
  denoiser.load_state_dict(state_dict['generator'])
210
  denoiser.eval()
@@ -219,7 +218,7 @@ def main():
219
  gr.Slider(0,1,0),
220
  gr.Slider(0,9999,1111)],
221
  outputs = 'audio',
222
- title = 'HierSpeech++',
223
  description = '''<div>
224
  <p style="text-align: left"> HierSpeech++ is a zero-shot speech synthesis model.</p>
225
  <p style="text-align: left"> Our model is trained with LibriTTS dataset so this model only supports english. We will release a multi-lingual HierSpeech++ soon.</p>
 
53
  random_seed):
54
 
55
  torch.manual_seed(random_seed)
 
56
  np.random.seed(random_seed)
57
 
58
  text_len = len(text)
 
62
  else:
63
  text = text_to_sequence(str(text), ["english_cleaners2"])
64
 
65
+ token = add_blank_token(text).unsqueeze(0)
66
+ token_length = torch.LongTensor([token.size(-1)])
67
 
68
  # Prompt load
69
  # sample_rate, audio = prompt
70
+ # audio = torch.FloatTensor([audio])
71
  # if audio.shape[0] != 1:
72
  # audio = audio[:1,:]
73
  # audio = audio / 32768
 
88
  # If you have a memory issue during denosing the prompt, try to denoise the prompt with cpu before TTS
89
  # We will have a plan to replace a memory-efficient denoiser
90
  if denoise == 0:
91
+ audio = torch.cat([audio, audio], dim=0)
92
  else:
93
  with torch.no_grad():
94
 
95
  if ori_prompt_len > 80000:
96
  denoised_audio = []
97
  for i in range((ori_prompt_len//80000)):
98
+ denoised_audio.append(denoise(audio.squeeze(0)[i*80000:(i+1)*80000], denoiser, hps_denoiser))
99
 
100
+ denoised_audio.append(denoise(audio.squeeze(0)[(i+1)*80000:], denoiser, hps_denoiser))
101
  denoised_audio = torch.cat(denoised_audio, dim=1)
102
  else:
103
+ denoised_audio = denoise(audio.squeeze(0), denoiser, hps_denoiser)
104
 
105
+ audio = torch.cat([audio, denoised_audio[:,:audio.shape[-1]]], dim=0)
106
 
107
  audio = audio[:,:ori_prompt_len] # 20231108 We found that large size of padding decreases a performance so we remove the paddings after denosing.
108
 
109
  if audio.shape[-1]<48000:
110
  audio = torch.cat([audio,audio,audio,audio,audio], dim=1)
111
 
112
+ src_mel = mel_fn(audio)
113
 
114
  src_length = torch.LongTensor([src_mel.size(2)]).to(device)
115
  src_length2 = torch.cat([src_length,src_length], dim=0)
 
119
  w2v_x, pitch = text2w2v.infer_noise_control(token, token_length, src_mel, src_length2,
120
  noise_scale=ttv_temperature, noise_scale_w=duratuion_temperature,
121
  length_scale=duratuion_length, denoise_ratio=denoise_ratio)
122
+ src_length = torch.LongTensor([w2v_x.size(2)])
123
 
124
+ pitch[pitch<torch.log(torch.tensor([55]))] = 0
125
 
126
  ## Hierarchical Speech Synthesizer (W2V, F0 --> 16k Audio)
127
  converted_audio = \
 
164
  a = parser.parse_args()
165
 
166
  global device, hps, hps_t2w2v,h_sr,h_sr48, hps_denoiser
167
+ device = 'cpu'
168
 
169
  hps = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt)[0], 'config.json'))
170
  hps_t2w2v = utils.get_hparams_from_file(os.path.join(os.path.split(a.ckpt_text2w2v)[0], 'config.json'))
 
183
  f_max=hps.data.mel_fmax,
184
  n_mels=hps.data.n_mel_channels,
185
  window_fn=torch.hann_window
186
+ )
187
 
188
  net_g = SynthesizerTrn(hps.data.filter_length // 2 + 1,
189
  hps.train.segment_size // hps.data.hop_length,
190
+ **hps.model)
191
  net_g.load_state_dict(torch.load(a.ckpt))
192
  _ = net_g.eval()
193
 
194
  text2w2v = Text2W2V(hps.data.filter_length // 2 + 1,
195
  hps.train.segment_size // hps.data.hop_length,
196
+ **hps_t2w2v.model)
197
  text2w2v.load_state_dict(torch.load(a.ckpt_text2w2v))
198
  text2w2v.eval()
199
 
200
  speechsr = SpeechSR48(h_sr48.data.n_mel_channels,
201
  h_sr48.train.segment_size // h_sr48.data.hop_length,
202
+ **h_sr48.model)
203
  utils.load_checkpoint(a.ckpt_sr48, speechsr, None)
204
  speechsr.eval()
205
 
206
+ denoiser = MPNet(hps_denoiser)
207
  state_dict = load_checkpoint(a.denoiser_ckpt, device)
208
  denoiser.load_state_dict(state_dict['generator'])
209
  denoiser.eval()
 
218
  gr.Slider(0,1,0),
219
  gr.Slider(0,9999,1111)],
220
  outputs = 'audio',
221
+ title = 'HierSpeech++ (CPU)',
222
  description = '''<div>
223
  <p style="text-align: left"> HierSpeech++ is a zero-shot speech synthesis model.</p>
224
  <p style="text-align: left"> Our model is trained with LibriTTS dataset so this model only supports english. We will release a multi-lingual HierSpeech++ soon.</p>