osyvokon commited on
Commit
7165c71
·
1 Parent(s): 84ba7aa

Add HuBERT-fbeeper demo

Browse files
Files changed (2) hide show
  1. fbeeper_hubert.py +157 -0
  2. requirements.txt +7 -0
fbeeper_hubert.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Beep profanity words in audio using one of the Hubert-compatible ASR models.
2
+ """
3
+
4
+ import argparse
5
+ import re
6
+ import logging
7
+
8
+ import soundfile
9
+ import transformers
10
+ import torch
11
+ import numpy as np
12
+
13
+
14
+ log = logging.getLogger(__name__)
15
+
16
+
17
+ class HubertBeeper:
18
+
19
+ PROFANITY = ["fuck", "shit", "piss"]
20
+
21
+ def __init__(self, model_name="facebook/hubert-large-ls960-ft"):
22
+ log.debug("Loading model: %s", model_name)
23
+ self.model_name = model_name
24
+ self.model = transformers.AutoModelForCTC.from_pretrained(model_name)
25
+ self.model.eval()
26
+ self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
27
+ self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
28
+ self.processor = transformers.Wav2Vec2Processor(
29
+ feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)
30
+
31
+ def asr(self, waveform, sample_rate):
32
+ features = self.processor([waveform], sampling_rate=sample_rate)
33
+ features = torch.tensor(features.input_values)
34
+ output = self.model(features)
35
+ return output
36
+
37
+ def f_beep(self, sound_file_path: str) -> np.array:
38
+ wav, sample_rate = soundfile.read(sound_file_path)
39
+ text, result_wav = self.f_beep_waveform(wav, sample_rate)
40
+ return result_wav
41
+
42
+ def f_beep_waveform(self, wav: np.array, sample_rate: int) -> np.array:
43
+ model_output = self.asr(wav, sample_rate)
44
+ text, spans = find_words_in_audio(model_output, self.processor, self.PROFANITY)
45
+ number_of_frames = model_output.logits.shape[1]
46
+ frame_size = len(wav) / number_of_frames
47
+
48
+ # Mask offsensive parts of the audio
49
+ for frame_begin, frame_end in spans:
50
+ begin = round(frame_begin * frame_size)
51
+ end = round(frame_end * frame_size)
52
+ self.generate_beep(wav, begin, end)
53
+
54
+ return text, wav
55
+
56
+ def generate_beep(self, wav, begin, end):
57
+ """Generate a beep over the selected region in audio.
58
+
59
+ Modifies waveform in place.
60
+ """
61
+
62
+ # Silence sounds better than beeps
63
+ for i in range(begin, end):
64
+ wav[i] = 0
65
+
66
+ def find_words_in_audio(model_output, processor, words):
67
+ """Return all frame spans that matches any of the `words`.
68
+ """
69
+
70
+ result_spans = []
71
+
72
+ token_ids = model_output.logits.argmax(dim=-1)[0]
73
+ vocab = processor.tokenizer.get_vocab()
74
+ text, offsets = decode_output_with_offsets(token_ids, vocab)
75
+ text = text.lower()
76
+ log.debug("ASR text: %s", text)
77
+
78
+ for word in words:
79
+ result_spans += find_spans(text, offsets, word)
80
+
81
+ log.debug("Spans: %s", result_spans)
82
+
83
+ return text, result_spans
84
+
85
+
86
+ def find_spans(text, offsets, word):
87
+ """Return all frame indexes that correspond to the given `word`.
88
+ """
89
+ spans = []
90
+ pattern = r"\b" + re.escape(word) + r"\b"
91
+ for match in re.finditer(pattern, text):
92
+ a = match.start()
93
+ b = match.end() + 1
94
+ start_frame = offsets[a]
95
+ end_frame = offsets[b] if b < len(offsets) else -1
96
+ spans.append((start_frame, end_frame))
97
+ return spans
98
+
99
+
100
+ def decode_output_with_offsets(decoded_token_ids, vocab):
101
+ """Given list of decoded tokens, return text and
102
+ time offsets that correspond to each character in the text.
103
+
104
+ Args:
105
+ decoded_token_ids (List[int]): list of token ids.
106
+ The length of the list should be equal to the number
107
+ of audio frames.
108
+ vocab (Dict[str, int]): model's vocabulary.
109
+
110
+ Returns:
111
+ Tuple[str, List[int]], where
112
+ `str` is a decoded text,
113
+ `List[int]` is a starting frame indexes for
114
+ every character in text.
115
+ """
116
+ token_by_index = {v: k for k, v in vocab.items()}
117
+ prev_token = None
118
+ result_string = []
119
+ result_offset = []
120
+ for i, token_id in enumerate(decoded_token_ids):
121
+ token_id = token_id.item()
122
+ if token_id == 0:
123
+ continue
124
+
125
+ token = token_by_index[token_id]
126
+ if token == prev_token:
127
+ continue
128
+
129
+ result_string.append(token)
130
+ result_offset.append(i)
131
+ prev_token = token
132
+
133
+ result_string = "".join(result_string).replace("|", " ")
134
+ assert len(result_string) == len(result_offset)
135
+ return result_string, result_offset
136
+
137
+
138
+
139
+ def main():
140
+ parser = argparse.ArgumentParser()
141
+ parser.add_argument("input")
142
+ parser.add_argument("-o", "--output")
143
+ parser.add_argument("-v", "--verbose", action="store_true")
144
+ parser.add_argument("--model", default="facebook/hubert-large-ls960-ft")
145
+ args = parser.parse_args()
146
+ logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
147
+
148
+ beeper = HubertBeeper(args.model)
149
+ result = beeper.f_beep(args.input)
150
+
151
+ output = args.output or "result.wav"
152
+ soundfile.write(output, result, 16000)
153
+ print(f"Saved to {output}")
154
+
155
+
156
+ if __name__ == "__main__":
157
+ main()
requirements.txt ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ transformers==4.12.3
2
+ pydub
3
+ soundfile
4
+ librosa
5
+ unidecode
6
+ inflect
7
+ torchaudio