jbetker commited on
Commit
ddb19f6
1 Parent(s): c1d004a

Enable redaction by default

Browse files
tortoise/api.py CHANGED
@@ -165,7 +165,7 @@ class TextToSpeech:
165
  Main entry point into Tortoise.
166
  """
167
 
168
- def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=False):
169
  """
170
  Constructor
171
  :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
@@ -275,7 +275,6 @@ class TextToSpeech:
275
  """
276
  # Use generally found best tuning knobs for generation.
277
  kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
278
- #'typical_sampling': True,
279
  'top_p': .8,
280
  'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
281
  # Presets are defined here.
 
165
  Main entry point into Tortoise.
166
  """
167
 
168
+ def __init__(self, autoregressive_batch_size=16, models_dir='.models', enable_redaction=True):
169
  """
170
  Constructor
171
  :param autoregressive_batch_size: Specifies how many samples to generate per batch. Lower this if you are seeing
 
275
  """
276
  # Use generally found best tuning knobs for generation.
277
  kwargs.update({'temperature': .8, 'length_penalty': 1.0, 'repetition_penalty': 2.0,
 
278
  'top_p': .8,
279
  'cond_free_k': 2.0, 'diffusion_temperature': 1.0})
280
  # Presets are defined here.
tortoise/utils/wav2vec_alignment.py CHANGED
@@ -7,13 +7,52 @@ from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTo
7
  from tortoise.utils.audio import load_audio
8
 
9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
10
  class Wav2VecAlignment:
 
 
 
11
  def __init__(self):
12
  self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
13
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
14
  self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
15
 
16
- def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False):
17
  orig_len = audio.shape[-1]
18
 
19
  with torch.no_grad():
@@ -25,32 +64,59 @@ class Wav2VecAlignment:
25
  self.model = self.model.cpu()
26
 
27
  logits = logits[0]
 
 
 
28
  w2v_compression = orig_len // logits.shape[0]
29
- expected_tokens = self.tokenizer.encode(expected_text)
 
30
  if len(expected_tokens) == 1:
31
  return [0] # The alignment is simple; there is only one token.
32
  expected_tokens.pop(0) # The first token is a given.
33
- next_expected_token = expected_tokens.pop(0)
 
34
  alignments = [0]
 
 
 
 
 
 
 
 
 
 
 
 
 
 
35
  for i, logit in enumerate(logits):
36
- top = logit.topk(topk).indices.tolist()
37
- if next_expected_token in top:
38
  alignments.append(i * w2v_compression)
39
  if len(expected_tokens) > 0:
40
- next_expected_token = expected_tokens.pop(0)
41
  else:
42
  break
43
 
44
- if len(expected_tokens) > 0:
45
- print(f"Alignment did not work. {len(expected_tokens)} were not found, with the following string un-aligned:"
46
- f" `{self.tokenizer.decode(expected_tokens)}`. Here's what wav2vec thought it heard:"
47
- f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`")
48
- if not return_partial:
49
- return None
50
 
51
- return alignments
 
 
 
 
 
 
 
 
 
 
52
 
53
- def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
 
 
54
  if '[' not in expected_text:
55
  return audio
56
  splitted = expected_text.split('[')
@@ -58,33 +124,22 @@ class Wav2VecAlignment:
58
  for spl in splitted[1:]:
59
  assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
60
  fully_split.extend(spl.split(']'))
61
- # Remove any non-alphabetic character in the input text. This makes matching more likely.
62
- fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]
63
  # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
64
  non_redacted_intervals = []
65
  last_point = 0
66
  for i in range(len(fully_split)):
67
  if i % 2 == 0:
68
- non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
 
69
  last_point += len(fully_split[i])
70
 
71
  bare_text = ''.join(fully_split)
72
- alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True)
73
- # If alignment fails, we will attempt to recover by assuming the remaining alignments consume the rest of the string.
74
- def get_alignment(i):
75
- if i >= len(alignments):
76
- return audio.shape[-1]
77
 
78
  output_audio = []
79
  for nri in non_redacted_intervals:
80
  start, stop = nri
81
- output_audio.append(audio[:, get_alignment(start):get_alignment(stop)])
82
  return torch.cat(output_audio, dim=-1)
83
 
84
-
85
- if __name__ == '__main__':
86
- some_audio = load_audio('../../results/train_dotrice_0.wav', 24000)
87
- aligner = Wav2VecAlignment()
88
- text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them."
89
- redact = aligner.redact(some_audio, text)
90
- torchaudio.save(f'test_output.wav', redact, 24000)
 
7
  from tortoise.utils.audio import load_audio
8
 
9
 
10
+ def max_alignment(s1, s2, skip_character='~', record={}):
11
+ """
12
+ A clever function that aligns s1 to s2 as best it can. Wherever a character from s1 is not found in s2, a '~' is
13
+ used to replace that character.
14
+
15
+ Finally got to use my DP skills!
16
+ """
17
+ assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"
18
+ if len(s1) == 0:
19
+ return ''
20
+ if len(s2) == 0:
21
+ return skip_character * len(s1)
22
+ if s1 == s2:
23
+ return s1
24
+ if s1[0] == s2[0]:
25
+ return s1[0] + max_alignment(s1[1:], s2[1:], skip_character, record)
26
+
27
+ take_s1_key = (len(s1), len(s2) - 1)
28
+ if take_s1_key in record:
29
+ take_s1, take_s1_score = record[take_s1_key]
30
+ else:
31
+ take_s1 = max_alignment(s1, s2[1:], skip_character, record)
32
+ take_s1_score = len(take_s1.replace(skip_character, ''))
33
+ record[take_s1_key] = (take_s1, take_s1_score)
34
+
35
+ take_s2_key = (len(s1) - 1, len(s2))
36
+ if take_s2_key in record:
37
+ take_s2, take_s2_score = record[take_s2_key]
38
+ else:
39
+ take_s2 = max_alignment(s1[1:], s2, skip_character, record)
40
+ take_s2_score = len(take_s2.replace(skip_character, ''))
41
+ record[take_s2_key] = (take_s2, take_s2_score)
42
+
43
+ return take_s1 if take_s1_score > take_s2_score else skip_character + take_s2
44
+
45
+
46
  class Wav2VecAlignment:
47
+ """
48
+ Uses wav2vec2 to perform audio<->text alignment.
49
+ """
50
  def __init__(self):
51
  self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
52
  self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(f"facebook/wav2vec2-large-960h")
53
  self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')
54
 
55
+ def align(self, audio, expected_text, audio_sample_rate=24000):
56
  orig_len = audio.shape[-1]
57
 
58
  with torch.no_grad():
 
64
  self.model = self.model.cpu()
65
 
66
  logits = logits[0]
67
+ pred_string = self.tokenizer.decode(logits.argmax(-1).tolist())
68
+
69
+ fixed_expectation = max_alignment(expected_text, pred_string)
70
  w2v_compression = orig_len // logits.shape[0]
71
+ expected_tokens = self.tokenizer.encode(fixed_expectation)
72
+ expected_chars = list(fixed_expectation)
73
  if len(expected_tokens) == 1:
74
  return [0] # The alignment is simple; there is only one token.
75
  expected_tokens.pop(0) # The first token is a given.
76
+ expected_chars.pop(0)
77
+
78
  alignments = [0]
79
+ def pop_till_you_win():
80
+ if len(expected_tokens) == 0:
81
+ return None
82
+ popped = expected_tokens.pop(0)
83
+ popped_char = expected_chars.pop(0)
84
+ while popped_char == '~':
85
+ alignments.append(-1)
86
+ if len(expected_tokens) == 0:
87
+ return None
88
+ popped = expected_tokens.pop(0)
89
+ popped_char = expected_chars.pop(0)
90
+ return popped
91
+
92
+ next_expected_token = pop_till_you_win()
93
  for i, logit in enumerate(logits):
94
+ top = logit.argmax()
95
+ if next_expected_token == top:
96
  alignments.append(i * w2v_compression)
97
  if len(expected_tokens) > 0:
98
+ next_expected_token = pop_till_you_win()
99
  else:
100
  break
101
 
102
+ pop_till_you_win()
103
+ assert len(expected_tokens) == 0, "This shouldn't happen. My coding sucks."
 
 
 
 
104
 
105
+ # Now fix up alignments. Anything with -1 should be interpolated.
106
+ alignments.append(orig_len) # This'll get removed but makes the algorithm below more readable.
107
+ for i in range(len(alignments)):
108
+ if alignments[i] == -1:
109
+ for j in range(i+1, len(alignments)):
110
+ if alignments[j] != -1:
111
+ next_found_token = j
112
+ break
113
+ for j in range(i, next_found_token):
114
+ gap = alignments[next_found_token] - alignments[i-1]
115
+ alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1]
116
 
117
+ return alignments[:-1]
118
+
119
+ def redact(self, audio, expected_text, audio_sample_rate=24000):
120
  if '[' not in expected_text:
121
  return audio
122
  splitted = expected_text.split('[')
 
124
  for spl in splitted[1:]:
125
  assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
126
  fully_split.extend(spl.split(']'))
127
+
 
128
  # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
129
  non_redacted_intervals = []
130
  last_point = 0
131
  for i in range(len(fully_split)):
132
  if i % 2 == 0:
133
+ end_interval = max(0, last_point + len(fully_split[i]) - 1)
134
+ non_redacted_intervals.append((last_point, end_interval))
135
  last_point += len(fully_split[i])
136
 
137
  bare_text = ''.join(fully_split)
138
+ alignments = self.align(audio, bare_text, audio_sample_rate)
 
 
 
 
139
 
140
  output_audio = []
141
  for nri in non_redacted_intervals:
142
  start, stop = nri
143
+ output_audio.append(audio[:, alignments[start]:alignments[stop]])
144
  return torch.cat(output_audio, dim=-1)
145