"""Beep profanity words in audio using one of the Hubert-compatible ASR models."""

import argparse
import logging
import re
import warnings
from typing import List, Tuple

import numpy as np
import torch
import transformers

try:
    import soundfile
except (ImportError, OSError):
    # soundfile is optional: the waveform-level API works without it.
    soundfile = None
    warnings.warn("Cannot import soundfile. Reading/writing files will be unavailable")

log = logging.getLogger(__name__)


class HubertBeeper:
    """Silence profanity in speech audio using a CTC speech-recognition model.

    The waveform is transcribed with a Hugging Face CTC model, the transcript
    is searched for the words in ``PROFANITY`` and the matching audio regions
    are zeroed out.
    """

    # Matched case-insensitively as whole words (regex ``\b`` boundaries),
    # so e.g. "fucking" is not matched by "fuck".
    PROFANITY = ["fuck", "shit", "piss"]

    def __init__(self, model_name="facebook/hubert-large-ls960-ft"):
        """Load the CTC model, feature extractor and tokenizer.

        Args:
            model_name: Hugging Face model identifier (or local path) of a
                CTC model usable with ``transformers.Wav2Vec2Processor``.
        """
        log.debug("Loading model: %s", model_name)
        self.model_name = model_name
        self.model = transformers.AutoModelForCTC.from_pretrained(model_name)
        self.model.eval()
        self.feature_extractor = transformers.AutoFeatureExtractor.from_pretrained(model_name)
        self.tokenizer = transformers.AutoTokenizer.from_pretrained(model_name)
        self.processor = transformers.Wav2Vec2Processor(
            feature_extractor=self.feature_extractor, tokenizer=self.tokenizer)

    def asr(self, waveform, sample_rate):
        """Run the acoustic model on a single waveform (a batch of one).

        Args:
            waveform: Array of audio samples.
            sample_rate: Sampling rate of `waveform` in Hz, forwarded to the
                processor.

        Returns:
            The raw model output; callers read ``output.logits``.
        """
        features = self.processor([waveform], sampling_rate=sample_rate, return_tensors="pt")
        # Inference only: do not build an autograd graph (saves memory/time).
        with torch.no_grad():
            return self.model(features.input_values)

    def f_beep(self, sound_file_path: str) -> np.ndarray:
        """Read an audio file and return its waveform with profanity silenced.

        Raises:
            RuntimeError: if the optional ``soundfile`` package is unavailable.
        """
        if soundfile is None:
            raise RuntimeError("soundfile is not available; cannot read audio files")
        wav, sample_rate = soundfile.read(sound_file_path)
        _, result_wav = self.f_beep_waveform(wav, sample_rate)
        return result_wav

    def f_beep_waveform(self, wav: np.ndarray, sample_rate: int) -> Tuple[str, np.ndarray]:
        """Silence every profanity occurrence in `wav`.

        Args:
            wav: Audio samples. Modified in place.
            sample_rate: Sampling rate of `wav` in Hz.

        Returns:
            ``(text, wav)``: the lower-cased ASR transcript and the same `wav`
            array with the offending regions zeroed.
        """
        model_output = self.asr(wav, sample_rate)
        text, spans = find_words_in_audio(model_output, self.processor, self.PROFANITY)
        number_of_frames = model_output.logits.shape[1]
        # Number of audio samples covered by one model output frame.
        frame_size = len(wav) / number_of_frames
        # Mask offensive parts of the audio.
        for frame_begin, frame_end in spans:
            begin = round(frame_begin * frame_size)
            end = round(frame_end * frame_size)
            self.generate_beep(wav, begin, end)
        return text, wav

    def generate_beep(self, wav, begin, end):
        """Silence ``wav[begin:end]`` in place.

        Silence sounds better than beeps, so the region is zeroed rather than
        overwritten with a tone. An empty or inverted range leaves `wav`
        untouched.
        """
        begin = max(begin, 0)
        if end > begin:
            wav[begin:end] = 0


def find_words_in_audio(model_output, processor, words):
    """Transcribe `model_output` and return frame spans matching any of `words`.

    Args:
        model_output: CTC model output with a ``logits`` tensor whose first
            batch element is decoded.
        processor: Processor whose ``tokenizer`` provides the vocabulary.
        words: Lower-case words to search for.

    Returns:
        ``(text, spans)``: the lower-cased transcript and a list of
        ``(start_frame, end_frame)`` pairs.
    """
    token_ids = model_output.logits.argmax(dim=-1)[0]
    tokenizer = processor.tokenizer
    vocab = tokenizer.get_vocab()
    # The pad token acts as the CTC blank; fall back to id 0 if unset.
    blank_id = getattr(tokenizer, "pad_token_id", None)
    if blank_id is None:
        blank_id = 0
    text, offsets = decode_output_with_offsets(token_ids, vocab, blank_id=blank_id)
    text = text.lower()
    log.debug("ASR text: %s", text)
    num_frames = len(token_ids)
    result_spans = []
    for word in words:
        result_spans += find_spans(text, offsets, word, num_frames=num_frames)
    log.debug("Spans: %s", result_spans)
    return text, result_spans


def find_spans(text, offsets, word, num_frames=-1):
    """Return ``(start_frame, end_frame)`` pairs for every whole-word `word` in `text`.

    Args:
        text: Decoded transcript.
        offsets: Starting frame index of every character in `text`.
        word: Word to search for (matched literally, as a whole word).
        num_frames: End frame used when a match runs to the end of `text`.
            Pass the total number of frames so that a trailing word is
            covered; the default -1 yields an empty span (legacy behaviour).
    """
    spans = []
    pattern = r"\b" + re.escape(word) + r"\b"
    for match in re.finditer(pattern, text):
        start_frame = offsets[match.start()]
        # Skip the character right after the word (normally the word
        # delimiter) so the span extends to the start of the next character.
        after = match.end() + 1
        end_frame = offsets[after] if after < len(offsets) else num_frames
        spans.append((start_frame, end_frame))
    return spans


def decode_output_with_offsets(decoded_token_ids, vocab, blank_id=0) -> Tuple[str, List[int]]:
    """Greedy CTC decoding that also returns a frame offset for every character.

    Args:
        decoded_token_ids (List[int]): per-frame token ids (tensor elements,
            numpy ints or plain ints). The length equals the number of audio
            frames.
        vocab (Dict[str, int]): model's vocabulary.
        blank_id (int): id of the CTC blank token.

    Returns:
        Tuple[str, List[int]]: the decoded text ("|" replaced by a space) and
        the starting frame index of every character in that text.
    """
    token_by_index = {v: k for k, v in vocab.items()}
    prev_token_id = None
    pieces = []
    offsets = []
    for frame, token_id in enumerate(decoded_token_ids):
        token_id = int(token_id)
        if token_id == blank_id:
            # A blank separates genuine repeats (e.g. the "ss" in "piss"),
            # so the next token must not be merged with the previous one.
            prev_token_id = None
            continue
        if token_id == prev_token_id:
            # Consecutive frames of the same token encode a single character.
            continue
        prev_token_id = token_id
        token = token_by_index[token_id]
        pieces.append(token)
        # One offset per character keeps text and offsets aligned even for
        # multi-character tokens such as "<unk>".
        offsets.extend([frame] * len(token))
    text = "".join(pieces).replace("|", " ")
    return text, offsets


def main():
    """Command-line entry point: silence profanity in an audio file."""
    parser = argparse.ArgumentParser(description=__doc__)
    parser.add_argument("input", help="input audio file")
    parser.add_argument("-o", "--output", help="output audio file (default: result.wav)")
    parser.add_argument("-v", "--verbose", action="store_true", help="enable debug logging")
    parser.add_argument("--model", default="facebook/hubert-large-ls960-ft",
                        help="Hugging Face CTC model to use")
    args = parser.parse_args()
    logging.basicConfig(level=logging.DEBUG if args.verbose else logging.INFO)
    # Fail fast, before the (slow) model download/load.
    if soundfile is None:
        raise SystemExit("soundfile is required to read and write audio files")
    beeper = HubertBeeper(args.model)
    wav, sample_rate = soundfile.read(args.input)
    _, result = beeper.f_beep_waveform(wav, sample_rate)
    output = args.output or "result.wav"
    # Write back at the rate the input was read at.
    soundfile.write(output, result, sample_rate)
    print(f"Saved to {output}")


if __name__ == "__main__":
    main()