File size: 6,376 Bytes
b36e9ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
import re

import torch
import torchaudio
from transformers import Wav2Vec2ForCTC, Wav2Vec2FeatureExtractor, Wav2Vec2CTCTokenizer, Wav2Vec2Processor

from tortoise.utils.audio import load_audio


def max_alignment(s1, s2, skip_character='~', record=None):
    """
    Align s1 to s2, keeping as many characters of s1 as possible.

    The result has the same length as s1. Characters of s1 that belong to a longest common subsequence with s2 are
    kept in place; every other character is replaced by `skip_character`. When skipping a character of s2 and skipping
    a character of s1 would keep equally many characters, the character of s1 is the one skipped.

    Computed iteratively (an LCS length table followed by a traceback) so that long strings cannot exceed Python's
    recursion limit.

    :param s1: String to align. Must not contain `skip_character`.
    :param s2: Reference string s1 is aligned against.
    :param skip_character: Placeholder written in place of characters of s1 that could not be aligned.
    :param record: Unused. Kept only so existing callers that pass it keep working.
    :return: s1 with every unalignable character replaced by `skip_character`.
    """
    assert skip_character not in s1, f"Found the skip character {skip_character} in the provided string, {s1}"

    # Equal leading characters are always kept as-is, so peel off the shared prefix first. This keeps the DP table
    # small for near-identical strings.
    limit = min(len(s1), len(s2))
    prefix_len = 0
    while prefix_len < limit and s1[prefix_len] == s2[prefix_len]:
        prefix_len += 1
    prefix = s1[:prefix_len]
    s1, s2 = s1[prefix_len:], s2[prefix_len:]
    n1, n2 = len(s1), len(s2)
    if n1 == 0:
        return prefix
    if n2 == 0:
        return prefix + skip_character * n1

    # lcs[i][j] = length of the longest common subsequence of s1[i:] and s2[j:]. The last row and column stay 0.
    lcs = [[0] * (n2 + 1) for _ in range(n1 + 1)]
    for i in range(n1 - 1, -1, -1):
        row, below = lcs[i], lcs[i + 1]
        c1 = s1[i]
        for j in range(n2 - 1, -1, -1):
            if c1 == s2[j]:
                row[j] = below[j + 1] + 1
            else:
                row[j] = max(row[j + 1], below[j])

    # Walk the table, preferring to drop a character of s2 only when that keeps strictly more of s1.
    out = [prefix]
    i = j = 0
    while i < n1 and j < n2:
        if s1[i] == s2[j]:
            out.append(s1[i])
            i += 1
            j += 1
        elif lcs[i][j + 1] > lcs[i + 1][j]:
            j += 1
        else:
            out.append(skip_character)
            i += 1
    # Anything left in s1 once s2 is exhausted cannot be aligned.
    out.append(skip_character * (n1 - i))
    return ''.join(out)


class Wav2VecAlignment:
    """
    Uses wav2vec2 to perform audio<->text alignment.

    The CTC model is kept on the CPU between calls and is only moved to `self.device` for the forward pass.
    """
    def __init__(self, device='cuda' if not torch.backends.mps.is_available() else 'mps'):
        """
        :param device: Device the wav2vec2 model runs on during `align`. The default is chosen once, when the class is
                       defined: 'mps' if the MPS backend is available, otherwise 'cuda'.
        """
        self.model = Wav2Vec2ForCTC.from_pretrained("jbetker/wav2vec2-large-robust-ft-libritts-voxpopuli").cpu()
        self.feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained("facebook/wav2vec2-large-960h")
        self.tokenizer = Wav2Vec2CTCTokenizer.from_pretrained('jbetker/tacotron-symbols')
        self.device = device

    def align(self, audio, expected_text, audio_sample_rate=24000):
        """
        Find, for each character of `expected_text`, the sample offset in `audio` at which it is spoken.

        :param audio: Audio tensor with time as its last dimension, sampled at `audio_sample_rate`. It is fed to the
                      model directly after resampling, so presumably shaped (1, samples) -- TODO confirm with callers.
        :param expected_text: Transcript of `audio`.
        :param audio_sample_rate: Sample rate of `audio`; it is resampled to 16kHz before inference.
        :return: One offset per character of `expected_text`, in units of `audio`'s last dimension. Characters the
                 model's transcription could not be matched to get offsets linearly interpolated between neighbours.
        :raises AssertionError: If the alignment cannot be reconciled with `expected_text`. A debug file,
                                'alignment_debug.pth', is written first.
        """
        orig_audio = audio
        orig_len = audio.shape[-1]

        with torch.no_grad():
            self.model = self.model.to(self.device)
            try:
                audio = audio.to(self.device)
                audio = torchaudio.functional.resample(audio, audio_sample_rate, 16000)
                clip_norm = (audio - audio.mean()) / torch.sqrt(audio.var() + 1e-7)
                logits = self.model(clip_norm).logits
            finally:
                # Park the model back on the CPU even when inference fails (e.g. out of memory).
                self.model = self.model.cpu()

        logits = logits[0]
        frame_ids = logits.argmax(-1).tolist()  # most likely token id for every output frame
        pred_string = self.tokenizer.decode(frame_ids)

        # Characters of the expected text that do not occur in the model's transcription become '~'.
        fixed_expectation = max_alignment(expected_text.lower(), pred_string)
        w2v_compression = orig_len // logits.shape[0]  # input samples per output frame
        expected_tokens = self.tokenizer.encode(fixed_expectation)
        expected_chars = list(fixed_expectation)
        if len(expected_tokens) == 1:
            return [0]  # The alignment is simple; there is only one token.
        expected_tokens.pop(0)  # The first token is a given.
        expected_chars.pop(0)

        alignments = [0]

        def pop_till_you_win():
            # Pop the next expected token, recording -1 (to be interpolated later) for every '~' placeholder that
            # max_alignment inserted. Returns None once the expected tokens run out.
            if len(expected_tokens) == 0:
                return None
            popped = expected_tokens.pop(0)
            popped_char = expected_chars.pop(0)
            while popped_char == '~':
                alignments.append(-1)
                if len(expected_tokens) == 0:
                    return None
                popped = expected_tokens.pop(0)
                popped_char = expected_chars.pop(0)
            return popped

        next_expected_token = pop_till_you_win()
        for i, top in enumerate(frame_ids):
            if next_expected_token == top:
                alignments.append(i * w2v_compression)
                if len(expected_tokens) > 0:
                    next_expected_token = pop_till_you_win()
                else:
                    break

        pop_till_you_win()  # consume any trailing '~' placeholders
        if not (len(expected_tokens) == 0 and len(alignments) == len(expected_text)):
            # Dump the inputs exactly as received so that `align` can be re-run on them to reproduce the failure.
            torch.save([orig_audio, expected_text], 'alignment_debug.pth')
            raise AssertionError("Something went wrong with the alignment algorithm. I've dumped a file, "
                                 "'alignment_debug.pth' to your current working directory. Please report this "
                                 "along with the file so it can get fixed.")

        # Now fix up alignments. Anything with -1 should be interpolated.
        alignments.append(orig_len)  # This'll get removed but makes the algorithm below more readable.
        for i in range(len(alignments)):
            if alignments[i] == -1:
                # alignments[0] is always 0 and the sentinel above is never -1, so both neighbours exist.
                for j in range(i+1, len(alignments)):
                    if alignments[j] != -1:
                        next_found_token = j
                        break
                for j in range(i, next_found_token):
                    gap = alignments[next_found_token] - alignments[i-1]
                    alignments[j] = (j-i+1) * gap // (next_found_token-i+1) + alignments[i-1]

        return alignments[:-1]

    def redact(self, audio, expected_text, audio_sample_rate=24000):
        """
        Remove the audio for every part of `expected_text` enclosed in square brackets.

        For example, with "[I am sad,] please feed me" only the audio aligned to " please feed me" is kept.

        :param audio: Audio tensor indexed as audio[:, samples], sampled at `audio_sample_rate`.
        :param expected_text: Transcript of `audio`, with the parts to remove wrapped in non-nested [ ] pairs.
        :param audio_sample_rate: Sample rate of `audio`.
        :return: `audio` itself if the text contains no '['. Otherwise the kept segments concatenated along the last
                 dimension; this is zero-length if everything is redacted.
        :raises AssertionError: If the square brackets are not properly paired.
        """
        if '[' not in expected_text:
            return audio
        splitted = expected_text.split('[')
        fully_split = [splitted[0]]
        for spl in splitted[1:]:
            # Exactly one ']' per '[': a missing one leaves the group unclosed, and an extra one would shift the
            # alternation of kept/redacted pieces below.
            if spl.count(']') != 1:
                raise AssertionError('Every "[" character must be paired with a "]" with no nesting.')
            fully_split.extend(spl.split(']'))

        # At this point, fully_split is a list of strings, with every other string being something that should be redacted.
        non_redacted_intervals = []
        last_point = 0
        for i in range(len(fully_split)):
            # Empty kept pieces are skipped: they would produce an inverted or out-of-range character interval.
            if i % 2 == 0 and fully_split[i] != "":
                end_interval = max(0, last_point + len(fully_split[i]) - 1)
                non_redacted_intervals.append((last_point, end_interval))
            last_point += len(fully_split[i])

        if not non_redacted_intervals:
            # Everything is redacted; torch.cat would reject an empty list, so return a zero-length clip.
            return audio[:, :0]

        bare_text = ''.join(fully_split)
        alignments = self.align(audio, bare_text, audio_sample_rate)

        output_audio = [audio[:, alignments[start]:alignments[stop]] for start, stop in non_redacted_intervals]
        return torch.cat(output_audio, dim=-1)