import re

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer

from tortoise.utils.audio import load_audio


class Wav2VecAlignment:
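    """
    Uses a CTC-trained wav2vec2 model to align a transcript with the audio clip it was spoken in.
    """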

    def __init__(self):
        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron_symbols')

    def align(self, audio, expected_text, audio_sample_rate=24000, topk=3, return_partial=False):
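        """
        Aligns `expected_text` against `audio`, returning a list of sample offsets into the clip, one per
        encoded token. Returns None on failure, or a partial list of offsets when `return_partial` is set.
        """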
        orig_len = audio.shape[-1]

        with torch.no_grad():
            # Use the GPU when one is available; the model is moved back to the CPU afterwards to free VRAM.
            device = 'cuda' if torch.cuda.is_available() else 'cpu'
            self.model = self.model.to(device)
            audio = audio.to(device)
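            # wav2vec2 expects 16kHz audio, normalized to zero mean and unit variance.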
            audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
            clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
            logits = self.model(clip_norm).logits
            self.model = self.model.cpu()

        logits = logits[0]
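        # Number of audio samples represented by each output frame of the model.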
        w2v_compression = orig_len // logits.shape[0]
        expected_tokens = self.tokenizer.encode(expected_text)
        if len(expected_tokens) == 1:
            return [0]
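        # The first token is assumed to sit at the very start of the clip; it is consumed here and
        # represented by the initial 0 in `alignments` below.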
        expected_tokens.pop(0)
        next_expected_token = expected_tokens.pop(0)
        alignments = [0]
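
        # Greedy scan over the CTC frames: whenever the next expected token appears in a frame's top-k
        # logits, record that frame's sample offset and advance to the following token.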
        for i, logit in enumerate(logits):
            top = logit.topk(topk).indices.tolist()
            if next_expected_token in top:
                alignments.append(i * w2v_compression)
                if len(expected_tokens) > 0:
                    next_expected_token = expected_tokens.pop(0)
                else:
                    break

        if len(expected_tokens) > 0:
            print(f"Alignment did not work. {len(expected_tokens)} tokens were not found, with the following string "
                  f"un-aligned: `{self.tokenizer.decode(expected_tokens)}`. Here's what wav2vec thought it heard: "
                  f"`{self.tokenizer.decode(logits.argmax(-1).tolist())}`")
            if not return_partial:
                return None

        return alignments

    def redact(self, audio, expected_text, audio_sample_rate=24000, topk=3):
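        """
        Removes the parts of `audio` corresponding to square-bracketed spans of `expected_text`, by
        aligning the bracket-free text against the clip and keeping only the unbracketed intervals.
        """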
        if '[' not in expected_text:
            return audio
        splitted = expected_text.split('[')
        fully_split = [splitted[0]]
        for spl in splitted[1:]:
            assert ']' in spl, 'Every "[" character must be paired with a "]" with no nesting.'
            fully_split.extend(spl.split(']'))
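
        # Strip everything that is not a letter or a space; stray punctuation only makes token matching less likely.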
        fully_split = [re.sub(r'[^a-zA-Z ]', '', s) for s in fully_split]
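
        # `fully_split` now alternates between spans to keep (even indices) and spans to redact (odd indices).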
        non_redacted_intervals = []
        last_point = 0
        for i in range(len(fully_split)):
            if i % 2 == 0:
                non_redacted_intervals.append((last_point, last_point + len(fully_split[i]) - 1))
            last_point += len(fully_split[i])

        bare_text = ''.join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate, topk, return_partial=True)
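
        # If alignment was only partial, treat any un-aligned remainder of the text as running to the end of the clip.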
        def get_alignment(i):
            if i >= len(alignments):
                return audio.shape[-1]
            return alignments[i]

        output_audio = []
        for nri in non_redacted_intervals:
            start, stop = nri
            output_audio.append(audio[:, get_alignment(start):get_alignment(stop)])
        return torch.cat(output_audio, dim=-1)


if __name__ == '__main__':
    some_audio = load_audio('../../results/train_dotrice_0.wav', 24000)
    aligner = Wav2VecAlignment()
    text = "[God fucking damn it I'm so angry] The expressiveness of autoregressive transformers is literally nuts! I absolutely adore them."
    redact = aligner.redact(some_audio, text)
    torchaudio.save('test_output.wav', redact, 24000)