Manmay committed on
Commit
7f2cd3b
1 Parent(s): d99cbfa

Update tortoise/api.py

Browse files
Files changed (1) hide show
  1. tortoise/api.py +25 -6
tortoise/api.py CHANGED
@@ -277,22 +277,41 @@ class TextToSpeech:
277
  settings.update(kwargs) # allow overriding of preset settings with kwargs
278
  for audio_frame in self.tts(text, **settings):
279
  yield audio_frame
280
-
281
- def handle_chunks(self, wav_gen, wav_gen_prev, wav_overlap, overlap_len):
282
- """Handle chunk formatting in streaming mode"""
 
 
 
 
 
 
 
 
283
  wav_chunk = wav_gen[:-overlap_len]
 
 
284
  if wav_gen_prev is not None:
285
  wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
 
 
286
  if wav_overlap is not None:
287
- crossfade_wav = wav_chunk[:overlap_len]
288
- crossfade_wav = crossfade_wav * torch.linspace(0.0, 1.0, overlap_len).to(crossfade_wav.device)
289
- wav_chunk[:overlap_len] = wav_overlap * torch.linspace(1.0, 0.0, overlap_len).to(wav_overlap.device)
 
290
  wav_chunk[:overlap_len] += crossfade_wav
 
 
291
  wav_overlap = wav_gen[-overlap_len:]
 
 
292
  wav_gen_prev = wav_gen
 
293
  return wav_chunk, wav_gen_prev, wav_overlap
294
 
295
 
 
296
  def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
297
  return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40,
298
  # autoregressive generation parameters follow
 
277
  settings.update(kwargs) # allow overriding of preset settings with kwargs
278
  for audio_frame in self.tts(text, **settings):
279
  yield audio_frame
280
+ def handle_chunks(
281
+ self,
282
+ wav_gen: torch.Tensor,
283
+ wav_gen_prev: torch.Tensor,
284
+ wav_overlap: torch.Tensor,
285
+ overlap_len: int
286
+ ) -> tuple:
287
+ """
288
+ Handle chunk formatting in streaming mode.
289
+ """
290
+ # Extract the current chunk without overlap
291
  wav_chunk = wav_gen[:-overlap_len]
292
+
293
+ # If there's a previous chunk, extract the portion that's not overlapping
294
  if wav_gen_prev is not None:
295
  wav_chunk = wav_gen[(wav_gen_prev.shape[0] - overlap_len) : -overlap_len]
296
+
297
+ # Perform the crossfade if there is an overlap
298
  if wav_overlap is not None:
299
+ crossfade_window = torch.linspace(0.0, 1.0, overlap_len).to(wav_gen.device)
300
+
301
+ crossfade_wav = wav_chunk[:overlap_len] * crossfade_window
302
+ wav_chunk[:overlap_len] = wav_overlap * (1 - crossfade_window)
303
  wav_chunk[:overlap_len] += crossfade_wav
304
+
305
+ # Save the last part of this chunk for overlapping with the next chunk
306
  wav_overlap = wav_gen[-overlap_len:]
307
+
308
+ # Update wav_gen_prev for the next iteration
309
  wav_gen_prev = wav_gen
310
+
311
  return wav_chunk, wav_gen_prev, wav_overlap
312
 
313
 
314
+
315
  def tts(self, text, voice_samples=None, conditioning_latents=None, k=1, verbose=True, use_deterministic_seed=None,
316
  return_deterministic_state=False, overlap_wav_len=1024, stream_chunk_size=40,
317
  # autoregressive generation parameters follow