jbetker committed
Commit 9007955
Parent: cd2d422

Add redaction support

tortoise/api.py CHANGED
@@ -19,6 +19,7 @@ from tortoise.models.vocoder import UnivNetGenerator
 from tortoise.utils.audio import wav_to_univnet_mel, denormalize_tacotron_mel
 from tortoise.utils.diffusion import SpacedDiffusion, space_timesteps, get_named_beta_schedule
 from tortoise.utils.tokenizer import VoiceBpeTokenizer
+from tortoise.utils.wav2vec_alignment import Wav2VecAlignment
 
 pbar = None
 
@@ -158,11 +159,23 @@ def classify_audio_clip(clip):
 class TextToSpeech:
     """
     Main entry point into Tortoise.
-    :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
-                                      GPU OOM errors. Larger numbers generates slightly faster.
     """
-    def __init__(self, autoregressive_batch_size=16, models_dir='.models'):
+
+    def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
+        """
+        Constructor
+        :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
+                                          GPU OOM errors. Larger numbers generate slightly faster.
+        :param models_dir: Where model weights are stored. This should only be specified if you are providing your own
+                           models, otherwise use the defaults.
+        :param enable_redaction: When true, text enclosed in brackets is automatically redacted from the spoken output
+                                 (but is still rendered by the model). This can be used for prompt engineering.
+        """
         self.autoregressive_batch_size = autoregressive_batch_size
+        self.enable_redaction = enable_redaction
+        if self.enable_redaction:
+            self.aligner = Wav2VecAlignment()
+
         self.tokenizer = VoiceBpeTokenizer()
         download_models()
 
@@ -380,7 +393,6 @@ class TextToSpeech:
             wav_candidates = []
             self.diffusion = self.diffusion.cuda()
             self.vocoder = self.vocoder.cuda()
-            diffusion_conds =
             for b in range(best_results.shape[0]):
                 codes = best_results[b].unsqueeze(0)
                 latents = best_latents[b].unsqueeze(0)
@@ -403,6 +415,12 @@ class TextToSpeech:
             self.diffusion = self.diffusion.cpu()
             self.vocoder = self.vocoder.cpu()
 
+            def potentially_redact(clip, text):
+                if self.enable_redaction:
+                    return self.aligner.redact(clip, text)
+                return clip
+            wav_candidates = [potentially_redact(wav_candidate, text) for wav_candidate in wav_candidates]
             if len(wav_candidates) > 1:
                 return wav_candidates
             return wav_candidates[0]
+
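For context, here is roughly how the new flag is exercised from the caller's side. This is a minimal sketch, not part of the commit: it assumes the existing `tts()` entry point with its `voice_samples` argument, and the voice clip path is hypothetical.

    from tortoise.api import TextToSpeech
    from tortoise.utils.audio import load_audio

    tts = TextToSpeech(enable_redaction=True)  # redaction is now on by default
    cond = load_audio('voices/myvoice.wav', 22050)  # hypothetical conditioning clip
    # The bracketed text steers delivery ("prompt engineering") but is cut out of
    # the returned waveform by Wav2VecAlignment.redact().
    wav = tts.tts("[I am so angry,] please leave.", voice_samples=[cond])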
tortoise/utils/wav2vec_alignment.py ADDED
@@ -0,0 +1,82 @@
+import torch
+import torchaudio
+from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor
+
+from tortoise.utils.audio import load_audio
+
+
+class Wav2VecAlignment:
+    def __init__(self):
+        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
+        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
+        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
+
+    def align(self, audio, expected_text, audio_sample_rate=24000, topk=3):
+        orig_len = audio.shape[-1]
+
+        with torch.no_grad():
+            self.model = self.model.cuda()
+            audio = audio.to('cuda')
+            audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
+            clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
+            logits = self.model(clip_norm).logits
+            self.model = self.model.cpu()
+
+        logits = logits[0]
+        w2v_compression = orig_len // logits.shape[0]
+        expected_tokens = self.tokenizer.encode(expected_text)
+        if len(expected_tokens) == 1:
+            return [0]  # The alignment is simple; there is only one token.
+        expected_tokens.pop(0)  # The first token is a given.
+        next_expected_token = expected_tokens.pop(0)
+        alignments = [0]
+        for i, logit in enumerate(logits):
+            top = logit.topk(topk).indices.tolist()
+            if next_expected_token in top:
+                alignments.append(i * w2v_compression)
+                if len(expected_tokens) > 0:
+                    next_expected_token = expected_tokens.pop(0)
+                else:
+                    break
+
+        if len(expected_tokens) > 0:
+            print(f"Alignment did not work. {len(expected_tokens)} tokens were not found, with the following string un-aligned:"
+                  f" {self.tokenizer.decode(expected_tokens)}")
+            return None
+
+        return alignments
+
+    def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
+        if '[' not in expected_text:
+            return audio
+        splitted = expected_text.split('[')
+        fully_split = [splitted[0]]
+        for spl in splitted[1:]:
+            assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
+            fully_split.extend(spl.split(']'))
+        # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
+        non_redacted_intervals = []
+        last_point = 0
+        for i in range(len(fully_split)):
+            if i % 2 == 0:
+                non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
+            last_point += len(fully_split[i])
+
+        bare_text = ''.join(fully_split)
+        alignments = self.align(audio, bare_text, audio_sample_rate, topk)
+        if alignments is None:
+            return audio  # Cannot redact because alignment did not succeed.
+
+        output_audio = []
+        for nri in non_redacted_intervals:
+            start, stop = nri
+            output_audio.append(audio[:, alignments[start]:alignments[stop]])
+        return torch.cat(output_audio, dim=-1)
+
+
+if __name__ == '__main__':
+    some_audio = load_audio('../../results/favorites/morgan_freeman_metallic_hydrogen.mp3', 24000)
+    aligner = Wav2VecAlignment()
+    text = "instead of molten iron, jupiter [and brown dwaves] have hydrogen, which [is under so much pressure that it] develops metallic properties"
+    redact = aligner.redact(some_audio, text)
+    torchaudio.save('test_output.wav', redact, 24000)
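To make the redaction bookkeeping concrete, the same bracket parsing and interval construction from `redact()` can be run standalone on a toy string (pure Python, no models; the values in the comments are what this code produces):

    text = 'hello [there] world'
    splitted = text.split('[')  # ['hello ', 'there] world']
    fully_split = [splitted[0]]
    for spl in splitted[1:]:
        fully_split.extend(spl.split(']'))
    # fully_split == ['hello ', 'there', ' world']; odd indices are redacted.
    non_redacted_intervals, last_point = [], 0
    for i, piece in enumerate(fully_split):
        if i % 2 == 0:
            non_redacted_intervals.append((last_point, last_point + len(piece) - 1))
        last_point += len(piece)
    # non_redacted_intervals == [(0, 5), (11, 16)]: inclusive character indices
    # into the bare text 'hello there world'. align() maps each character index
    # to an audio sample offset, and redact() concatenates the kept intervals.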