Upload seamless_communication/cli/toxicity/asr_etox.py with huggingface_hub
Browse files
seamless_communication/cli/toxicity/asr_etox.py
ADDED
@@ -0,0 +1,255 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# MIT_LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import tempfile
|
9 |
+
import typing as tp
|
10 |
+
import torchaudio
|
11 |
+
from tqdm import tqdm
|
12 |
+
from seamless_communication.cli.eval_utils.compute_metrics import init_whisper_model
|
13 |
+
from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
|
14 |
+
from seamless_communication.inference.translator import Modality
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from pathlib import Path
|
18 |
+
from seamless_communication.inference import Translator
|
19 |
+
from fairseq2.data import Collater, DataPipeline, FileMapper
|
20 |
+
from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
|
21 |
+
from fairseq2.data.text import StrSplitter, read_text
|
22 |
+
from fairseq2.typing import DataType, Device
|
23 |
+
|
24 |
+
from seamless_communication.toxicity import load_etox_bad_word_checker
|
25 |
+
|
26 |
+
from whisper.model import Whisper
|
27 |
+
|
28 |
+
import logging
|
29 |
+
|
30 |
+
logging.basicConfig(
|
31 |
+
level=logging.INFO,
|
32 |
+
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
33 |
+
)
|
34 |
+
|
35 |
+
logger = logging.getLogger(__name__)
|
36 |
+
|
37 |
+
|
38 |
+
def main() -> None:
|
39 |
+
parser = argparse.ArgumentParser(
|
40 |
+
description="ASR ETOX will compute the toxicity level of speech inputs."
|
41 |
+
)
|
42 |
+
parser.add_argument(
|
43 |
+
"data_file",
|
44 |
+
type=Path,
|
45 |
+
help="Path to the input TSV manifest that list the audio files.",
|
46 |
+
)
|
47 |
+
parser.add_argument(
|
48 |
+
"output_file",
|
49 |
+
type=Path,
|
50 |
+
help="Path to a TSV file where to save the results.",
|
51 |
+
)
|
52 |
+
parser.add_argument(
|
53 |
+
"--lang",
|
54 |
+
type=str,
|
55 |
+
help="Language, language of the speech to transcribe",
|
56 |
+
required=True,
|
57 |
+
)
|
58 |
+
parser.add_argument(
|
59 |
+
"--audio_root_dir",
|
60 |
+
type=str,
|
61 |
+
help="Root directory for the audio filenames in the data file.",
|
62 |
+
)
|
63 |
+
parser.add_argument(
|
64 |
+
"--audio_column",
|
65 |
+
type=str,
|
66 |
+
help="Name of the column where the audiofile is listed in the input tsv.",
|
67 |
+
default="audio",
|
68 |
+
)
|
69 |
+
parser.add_argument(
|
70 |
+
"--model_name",
|
71 |
+
type=str,
|
72 |
+
help=(
|
73 |
+
"Base model name (`seamlessM4T_medium`, "
|
74 |
+
"`seamlessM4T_large`, `seamlessM4T_v2_large`), "
|
75 |
+
" or whisper model, e.g. 'whisper_large'"
|
76 |
+
),
|
77 |
+
default="seamlessM4T_v2_large",
|
78 |
+
)
|
79 |
+
parser.add_argument(
|
80 |
+
"--batch_size",
|
81 |
+
type=int,
|
82 |
+
help="Inference batch size.",
|
83 |
+
default=4,
|
84 |
+
)
|
85 |
+
parser.add_argument(
|
86 |
+
"--n_parallel",
|
87 |
+
type=int,
|
88 |
+
help="Number of data loading in parallel.",
|
89 |
+
default=4,
|
90 |
+
)
|
91 |
+
args, _unknown = parser.parse_known_args()
|
92 |
+
|
93 |
+
if torch.cuda.is_available():
|
94 |
+
device = torch.device("cuda:0")
|
95 |
+
dtype = torch.float16
|
96 |
+
else:
|
97 |
+
device = torch.device("cpu")
|
98 |
+
dtype = torch.float32
|
99 |
+
|
100 |
+
whisper_model = None
|
101 |
+
translator = None
|
102 |
+
is_whisper = False
|
103 |
+
|
104 |
+
if args.model_name.startswith("whisper_"):
|
105 |
+
logger.info("loading whisper model.")
|
106 |
+
_, model_name = args.model_name.split("_", maxsplit=1)
|
107 |
+
whisper_model = init_whisper_model(device, model_name)
|
108 |
+
is_whisper = True
|
109 |
+
else:
|
110 |
+
logger.info(f"loading {args.model_name} model.")
|
111 |
+
translator = Translator(
|
112 |
+
args.model_name,
|
113 |
+
None,
|
114 |
+
device,
|
115 |
+
text_tokenizer=None,
|
116 |
+
dtype=dtype,
|
117 |
+
input_modality=Modality.SPEECH,
|
118 |
+
output_modality=Modality.TEXT,
|
119 |
+
apply_mintox=False,
|
120 |
+
)
|
121 |
+
|
122 |
+
logger.info("loading etox.")
|
123 |
+
bad_word_checker = load_etox_bad_word_checker("mintox")
|
124 |
+
|
125 |
+
pipeline = build_data_pipeline(
|
126 |
+
data_file=args.data_file,
|
127 |
+
audio_root_dir=args.audio_root_dir,
|
128 |
+
batch_size=args.batch_size,
|
129 |
+
is_whisper=is_whisper,
|
130 |
+
device=device,
|
131 |
+
dtype=dtype,
|
132 |
+
n_parallel=args.n_parallel,
|
133 |
+
audio_column=args.audio_column,
|
134 |
+
)
|
135 |
+
|
136 |
+
logger.info("running ASR-ETOX.")
|
137 |
+
with open(args.output_file, "w", encoding="utf-8") as outf:
|
138 |
+
print("text", "toxicity", "bad_words", file=outf, sep="\t")
|
139 |
+
for example in tqdm(pipeline, unit="line"):
|
140 |
+
texts = get_text(
|
141 |
+
lang=args.lang,
|
142 |
+
example=example,
|
143 |
+
whisper_model=whisper_model,
|
144 |
+
translator=translator,
|
145 |
+
audio_column=args.audio_column,
|
146 |
+
)
|
147 |
+
for t in texts:
|
148 |
+
bad_words = bad_word_checker.get_bad_words(
|
149 |
+
text=str(t),
|
150 |
+
lang=args.lang,
|
151 |
+
)
|
152 |
+
print(
|
153 |
+
t,
|
154 |
+
len(bad_words),
|
155 |
+
",".join(bad_words),
|
156 |
+
file=outf,
|
157 |
+
sep="\t",
|
158 |
+
)
|
159 |
+
|
160 |
+
|
161 |
+
def get_text(
|
162 |
+
lang: str,
|
163 |
+
example: tp.Dict[str, tp.Any],
|
164 |
+
whisper_model: Whisper,
|
165 |
+
translator: Translator,
|
166 |
+
audio_column: str,
|
167 |
+
):
|
168 |
+
if whisper_model:
|
169 |
+
with tempfile.NamedTemporaryFile(suffix=".wav") as temp:
|
170 |
+
torchaudio.save(
|
171 |
+
temp.name,
|
172 |
+
example[audio_column]["data"]["waveform"]["seqs"][0]
|
173 |
+
.transpose(0, 1)
|
174 |
+
.cpu(),
|
175 |
+
int(example[audio_column]["data"]["sample_rate"][0]),
|
176 |
+
format="wav",
|
177 |
+
)
|
178 |
+
results = whisper_model.transcribe(
|
179 |
+
temp.name,
|
180 |
+
language=LANG3_LANG2[lang],
|
181 |
+
)
|
182 |
+
return [results["text"]]
|
183 |
+
else:
|
184 |
+
(text_output, _speech_output) = translator.predict(
|
185 |
+
example[audio_column]["data"]["fbank"],
|
186 |
+
"ASR",
|
187 |
+
lang,
|
188 |
+
src_lang=lang,
|
189 |
+
)
|
190 |
+
return text_output
|
191 |
+
|
192 |
+
|
193 |
+
def build_data_pipeline(
|
194 |
+
data_file: Path,
|
195 |
+
audio_root_dir: str,
|
196 |
+
batch_size: int,
|
197 |
+
is_whisper: bool,
|
198 |
+
device: Device,
|
199 |
+
dtype: DataType,
|
200 |
+
audio_column: str = "audio",
|
201 |
+
n_parallel: int = 4,
|
202 |
+
) -> DataPipeline:
|
203 |
+
with data_file.open("r", encoding="utf-8") as f:
|
204 |
+
header = f.readline().strip("\n").split("\t")
|
205 |
+
|
206 |
+
split_tsv = StrSplitter(names=header)
|
207 |
+
|
208 |
+
pipeline_builder = read_text(data_file, rtrim=True).skip(1).map(split_tsv)
|
209 |
+
|
210 |
+
map_file = FileMapper(root_dir=audio_root_dir, cached_fd_count=10)
|
211 |
+
|
212 |
+
pipeline_builder.map(
|
213 |
+
map_file,
|
214 |
+
selector=audio_column,
|
215 |
+
num_parallel_calls=n_parallel,
|
216 |
+
)
|
217 |
+
|
218 |
+
decode_audio = AudioDecoder(dtype=torch.float32, device=device)
|
219 |
+
|
220 |
+
convert_to_fbank = WaveformToFbankConverter(
|
221 |
+
num_mel_bins=80,
|
222 |
+
waveform_scale=2**15,
|
223 |
+
channel_last=True,
|
224 |
+
standardize=True,
|
225 |
+
device=device,
|
226 |
+
dtype=dtype,
|
227 |
+
)
|
228 |
+
|
229 |
+
# get tensor in waveform
|
230 |
+
steps = [decode_audio]
|
231 |
+
if not is_whisper:
|
232 |
+
# also get the fbanks
|
233 |
+
steps.append(convert_to_fbank)
|
234 |
+
|
235 |
+
pipeline_builder.map(
|
236 |
+
steps,
|
237 |
+
selector=f"{audio_column}.data",
|
238 |
+
num_parallel_calls=n_parallel,
|
239 |
+
)
|
240 |
+
|
241 |
+
if is_whisper:
|
242 |
+
# no batching for whisper
|
243 |
+
pipeline_builder.bucket(bucket_size=batch_size)
|
244 |
+
|
245 |
+
collate = Collater(pad_value=0, pad_to_multiple=1)
|
246 |
+
|
247 |
+
pipeline_builder.map(collate, num_parallel_calls=n_parallel)
|
248 |
+
|
249 |
+
pipeline_builder.prefetch(4)
|
250 |
+
|
251 |
+
return pipeline_builder.and_return()
|
252 |
+
|
253 |
+
|
254 |
+
if __name__ == "__main__":
|
255 |
+
main()
|