victan committed
Commit e38a9e2
1 Parent(s): a7cbdc9

Upload seamless_communication/cli/toxicity/asr_etox.py with huggingface_hub

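The commit message above indicates the file was uploaded with huggingface_hub. For reference, a minimal sketch of such an upload via the library's HfApi.upload_file API is shown below; the repo_id and local path are placeholders, not values taken from this commit.

from huggingface_hub import HfApi

api = HfApi()
# Upload a single local file into the target repo at the given in-repo path.
api.upload_file(
    path_or_fileobj="seamless_communication/cli/toxicity/asr_etox.py",  # local file (placeholder)
    path_in_repo="seamless_communication/cli/toxicity/asr_etox.py",     # destination path in the repo
    repo_id="<namespace>/<repo>",                                       # placeholder repo id
    commit_message="Upload seamless_communication/cli/toxicity/asr_etox.py with huggingface_hub",
)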
seamless_communication/cli/toxicity/asr_etox.py ADDED
@@ -0,0 +1,255 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates
+ # All rights reserved.
+ #
+ # This source code is licensed under the license found in the
+ # MIT_LICENSE file in the root directory of this source tree.
+
+ import argparse
+ import tempfile
+ import typing as tp
+ import torchaudio
+ from tqdm import tqdm
+ from seamless_communication.cli.eval_utils.compute_metrics import init_whisper_model
+ from seamless_communication.cli.eval_utils.lang_mapping import LANG3_LANG2
+ from seamless_communication.inference.translator import Modality
+ import torch
+
+ from pathlib import Path
+ from seamless_communication.inference import Translator
+ from fairseq2.data import Collater, DataPipeline, FileMapper
+ from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+ from fairseq2.data.text import StrSplitter, read_text
+ from fairseq2.typing import DataType, Device
+
+ from seamless_communication.toxicity import load_etox_bad_word_checker
+
+ from whisper.model import Whisper
+
+ import logging
+
+ logging.basicConfig(
+     level=logging.INFO,
+     format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+ )
+
+ logger = logging.getLogger(__name__)
+
+
+ def main() -> None:
+     parser = argparse.ArgumentParser(
+         description="ASR-ETOX computes the toxicity level of speech inputs."
+     )
+     parser.add_argument(
+         "data_file",
+         type=Path,
+         help="Path to the input TSV manifest that lists the audio files.",
+     )
+     parser.add_argument(
+         "output_file",
+         type=Path,
+         help="Path to the TSV file where the results will be saved.",
+     )
+     parser.add_argument(
+         "--lang",
+         type=str,
+         help="Language of the speech to transcribe.",
+         required=True,
+     )
+     parser.add_argument(
+         "--audio_root_dir",
+         type=str,
+         help="Root directory for the audio filenames in the data file.",
+     )
+     parser.add_argument(
+         "--audio_column",
+         type=str,
+         help="Name of the column in the input TSV that lists the audio files.",
+         default="audio",
+     )
+     parser.add_argument(
+         "--model_name",
+         type=str,
+         help=(
+             "Base model name (`seamlessM4T_medium`, "
+             "`seamlessM4T_large`, `seamlessM4T_v2_large`) "
+             "or a whisper model, e.g. 'whisper_large'."
+         ),
+         default="seamlessM4T_v2_large",
+     )
+     parser.add_argument(
+         "--batch_size",
+         type=int,
+         help="Inference batch size.",
+         default=4,
+     )
+     parser.add_argument(
+         "--n_parallel",
+         type=int,
+         help="Number of parallel calls used for data loading.",
+         default=4,
+     )
+     args, _unknown = parser.parse_known_args()
+
+     if torch.cuda.is_available():
+         device = torch.device("cuda:0")
+         dtype = torch.float16
+     else:
+         device = torch.device("cpu")
+         dtype = torch.float32
+
+     whisper_model = None
+     translator = None
+     is_whisper = False
+
+     if args.model_name.startswith("whisper_"):
+         logger.info("loading whisper model.")
+         _, model_name = args.model_name.split("_", maxsplit=1)
+         whisper_model = init_whisper_model(device, model_name)
+         is_whisper = True
+     else:
+         logger.info(f"loading {args.model_name} model.")
+         translator = Translator(
+             args.model_name,
+             None,
+             device,
+             text_tokenizer=None,
+             dtype=dtype,
+             input_modality=Modality.SPEECH,
+             output_modality=Modality.TEXT,
+             apply_mintox=False,
+         )
+
+     logger.info("loading etox.")
+     bad_word_checker = load_etox_bad_word_checker("mintox")
+
+     pipeline = build_data_pipeline(
+         data_file=args.data_file,
+         audio_root_dir=args.audio_root_dir,
+         batch_size=args.batch_size,
+         is_whisper=is_whisper,
+         device=device,
+         dtype=dtype,
+         n_parallel=args.n_parallel,
+         audio_column=args.audio_column,
+     )
+
+     logger.info("running ASR-ETOX.")
+     with open(args.output_file, "w", encoding="utf-8") as outf:
+         # output columns: transcript, number of matched bad words, the matched words
+         print("text", "toxicity", "bad_words", file=outf, sep="\t")
+         for example in tqdm(pipeline, unit="line"):
+             texts = get_text(
+                 lang=args.lang,
+                 example=example,
+                 whisper_model=whisper_model,
+                 translator=translator,
+                 audio_column=args.audio_column,
+             )
+             for t in texts:
+                 bad_words = bad_word_checker.get_bad_words(
+                     text=str(t),
+                     lang=args.lang,
+                 )
+                 print(
+                     t,
+                     len(bad_words),
+                     ",".join(bad_words),
+                     file=outf,
+                     sep="\t",
+                 )
+
+
+ def get_text(
+     lang: str,
+     example: tp.Dict[str, tp.Any],
+     whisper_model: tp.Optional[Whisper],
+     translator: tp.Optional[Translator],
+     audio_column: str,
+ ):
+     if whisper_model:
+         with tempfile.NamedTemporaryFile(suffix=".wav") as temp:
+             torchaudio.save(
+                 temp.name,
+                 example[audio_column]["data"]["waveform"]["seqs"][0]
+                 .transpose(0, 1)
+                 .cpu(),
+                 int(example[audio_column]["data"]["sample_rate"][0]),
+                 format="wav",
+             )
+             results = whisper_model.transcribe(
+                 temp.name,
+                 language=LANG3_LANG2[lang],
+             )
+             return [results["text"]]
+     else:
+         (text_output, _speech_output) = translator.predict(
+             example[audio_column]["data"]["fbank"],
+             "ASR",
+             lang,
+             src_lang=lang,
+         )
+         return text_output
+
+
+ def build_data_pipeline(
+     data_file: Path,
+     audio_root_dir: str,
+     batch_size: int,
+     is_whisper: bool,
+     device: Device,
+     dtype: DataType,
+     audio_column: str = "audio",
+     n_parallel: int = 4,
+ ) -> DataPipeline:
+     with data_file.open("r", encoding="utf-8") as f:
+         header = f.readline().strip("\n").split("\t")
+
+     split_tsv = StrSplitter(names=header)
+
+     pipeline_builder = read_text(data_file, rtrim=True).skip(1).map(split_tsv)
+
+     map_file = FileMapper(root_dir=audio_root_dir, cached_fd_count=10)
+
+     pipeline_builder.map(
+         map_file,
+         selector=audio_column,
+         num_parallel_calls=n_parallel,
+     )
+
+     decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
+     convert_to_fbank = WaveformToFbankConverter(
+         num_mel_bins=80,
+         waveform_scale=2**15,
+         channel_last=True,
+         standardize=True,
+         device=device,
+         dtype=dtype,
+     )
+
+     # decode the audio bytes into a waveform tensor
+     steps = [decode_audio]
+     if not is_whisper:
+         # the translator additionally needs fbank features
+         steps.append(convert_to_fbank)
+
+     pipeline_builder.map(
+         steps,
+         selector=f"{audio_column}.data",
+         num_parallel_calls=n_parallel,
+     )
+
+     if is_whisper:
+         # no batching for whisper: it transcribes one utterance at a time
+         pipeline_builder.bucket(bucket_size=1)
+     else:
+         pipeline_builder.bucket(bucket_size=batch_size)
+
+     collate = Collater(pad_value=0, pad_to_multiple=1)
+
+     pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+     pipeline_builder.prefetch(4)
+
+     return pipeline_builder.and_return()
+
+
+ if __name__ == "__main__":
+     main()
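For reference, the script reads a tab-separated manifest whose header must contain the audio column (default "audio") and writes a three-column TSV ("text", "toxicity", "bad_words"), where "toxicity" is the number of matched bad words. A hypothetical invocation, assuming the package and its dependencies (fairseq2, whisper, torchaudio) are installed and the module is importable under this path, could look like:

python -m seamless_communication.cli.toxicity.asr_etox manifest.tsv etox_results.tsv --lang eng --audio_root_dir /data/audio --model_name seamlessM4T_v2_large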