victan committed
Commit 7ddea44
1 Parent(s): e6ea21a

Upload seamless_communication/cli/m4t/evaluate/evaluate.py with huggingface_hub

seamless_communication/cli/m4t/evaluate/evaluate.py ADDED
@@ -0,0 +1,440 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import itertools
+import logging
+import subprocess
+from argparse import Namespace
+from dataclasses import dataclass
+from pathlib import Path
+from typing import Any, Dict, List, Optional, Tuple
+
+import torch
+import torchaudio
+from fairseq2.data import Collater, DataPipeline, FileMapper
+from fairseq2.data.audio import AudioDecoder, WaveformToFbankConverter
+from fairseq2.data.text import StrSplitter, TextTokenizer, read_text
+from fairseq2.data.typing import StringLike
+from fairseq2.typing import DataType, Device
+from torch import Tensor
+from tqdm import tqdm
+
+from seamless_communication.cli.eval_utils import (
+    compute_quality_metrics,
+)
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import (
+    BatchedSpeechOutput,
+    Modality,
+    SequenceGeneratorOptions,
+    Translator,
+)
+from seamless_communication.models.unity import load_unity_text_tokenizer
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+@dataclass
+class EvalContext:
+    task: str
+    """String representing the task. Valid choices are
+    "S2ST", "S2TT", "T2ST", "T2TT", "ASR"."""
+
+    input_modality: Modality
+    """The input modality of the task."""
+
+    output_modality: Modality
+    """The output modality of the task."""
+
+    model_name: str
+    """The name of the S2T UnitY model."""
+
+    data_file: Path
+    """The pathname of the test TSV data file."""
+
+    audio_root_dir: Optional[Path]
+    """The pathname of the directory under which
+    audio files are stored."""
+
+    target_lang: str
+    """The target translation language."""
+
+    source_lang: Optional[str]
+    """The source language."""
+
+    batch_size: int
+    """The batch size for model input."""
+
+    device: Device
+    """The device on which to run inference."""
+
+    dtype: DataType
+    """The data type with which to run inference."""
+
+    output_path: Path
+    """The pathname of the output directory to save
+    the evaluation results."""
+
+    ref_field: str
+    """The reference target text field to compute
+    the BLEU score against."""
+
+    text_generation_opts: SequenceGeneratorOptions
+    """Text generation hyperparameters."""
+
+    unit_generation_opts: Optional[SequenceGeneratorOptions]
+    """Unit generation hyperparameters, not applicable
+    for the NAR T2U decoder."""
+
+    unit_generation_ngram_filtering: bool
+    """If True, removes consecutive repeating ngrams
+    from the decoded unit output."""
+
+
+def count_lines(filename: Path) -> int:
+    result = subprocess.run(["wc", "-l", filename], stdout=subprocess.PIPE)
+    return int(result.stdout.decode().split()[0])
+
+
+def build_data_pipeline(
+    ctx: EvalContext,
+    text_tokenizer: TextTokenizer,
+) -> DataPipeline:
+    with open(ctx.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+        first_example = f.readline().strip("\n").split("\t")
+
+    # TODO: This will be soon auto-tuned. Right now hand-tuned for devfair.
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
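+    # Read the TSV line by line, skip the header row, and split each row into
+    # a dict keyed by the header column names.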
+    pipeline_builder = read_text(ctx.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    if ctx.input_modality == Modality.SPEECH:
+        assert ctx.audio_root_dir is not None
+
+        map_file = FileMapper(root_dir=ctx.audio_root_dir, cached_fd_count=10)
+
+        pipeline_builder.map(map_file, selector="audio", num_parallel_calls=n_parallel)
+
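+        # Decode each audio file and convert the waveform to 80-bin mel
+        # filterbank features on the inference device and dtype.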
+        decode_audio = AudioDecoder(dtype=torch.float32, device=ctx.device)
+
+        convert_to_fbank = WaveformToFbankConverter(
+            num_mel_bins=80,
+            waveform_scale=2**15,
+            channel_last=True,
+            standardize=True,
+            device=ctx.device,
+            dtype=ctx.dtype,
+        )
+
+        pipeline_builder.map(
+            [decode_audio, convert_to_fbank],
+            selector="audio.data",
+            num_parallel_calls=n_parallel,
+        )
+    else:
+        if "src_lang" in header:
+            source_lang = first_example[header.index("src_lang")]
+            ctx.source_lang = source_lang
+        elif ctx.source_lang is None:
+            raise ValueError(
+                (
+                    "'src_lang' is missing in the data_file "
+                    "header and in the arguments."
+                )
+            )
+
+        token_encoder = text_tokenizer.create_encoder(
+            task="translation", lang=ctx.source_lang, mode="source", device=ctx.device
+        )
+        pipeline_builder.map(
+            [token_encoder],
+            selector="src_text",
+            num_parallel_calls=n_parallel,
+        )
+
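+    # Group examples into buckets of batch_size, then pad and collate each
+    # bucket into batched tensors (pad_value=0).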
+    pipeline_builder.bucket(bucket_size=ctx.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def adjust_output_for_corrupted_inputs(
+    valid_sequences: Tensor,
+    text_output: List[StringLike],
+    speech_output: Optional[BatchedSpeechOutput],
+) -> Tuple[List[StringLike], Optional[BatchedSpeechOutput]]:
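+    # Re-align outputs with the original batch order: valid positions keep the
+    # model outputs, corrupted positions get placeholder outputs so hypothesis
+    # rows stay aligned with reference rows.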
+    adjusted_text_output: List[StringLike] = []
+    adjusted_speech_output: Optional[BatchedSpeechOutput] = None
+
+    if speech_output is not None:
+        assert (
+            len(text_output)
+            == len(speech_output.units)
+            == len(speech_output.audio_wavs)
+        )
+        adjusted_speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+    batch_counter = 0
+    for is_valid in valid_sequences:
+        if is_valid:
+            adjusted_text_output.append(text_output[batch_counter])
+            if speech_output is not None:
+                assert adjusted_speech_output is not None
+                adjusted_speech_output.units.append(speech_output.units[batch_counter])
+                adjusted_speech_output.audio_wavs.append(
+                    speech_output.audio_wavs[batch_counter]
+                )
+            batch_counter += 1
+        else:
+            # For the corrupted inputs, we save the following dummy outputs:
+            # empty string for text, empty list for units, 1 second of silence for audio.
+            adjusted_text_output.append("")
+            if adjusted_speech_output is not None:
+                sample_rate = adjusted_speech_output.sample_rate
+                adjusted_speech_output.units.append([])
+                adjusted_speech_output.audio_wavs.append(
+                    torch.zeros(sample_rate).unsqueeze(0).unsqueeze(0)
+                )
+    return (
+        adjusted_text_output,
+        adjusted_speech_output,
+    )
+
+
+def run_eval(
+    translator: Translator,
+    text_tokenizer: TextTokenizer,
+    ctx: EvalContext,
+    whisper_model_name: str,
+) -> None:
+    pipeline = build_data_pipeline(ctx, text_tokenizer)
+
+    total_steps = count_lines(ctx.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = ctx.output_path / ctx.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    if ctx.output_modality == Modality.SPEECH:
+        waveforms_dir = output_path / f"waveform_{ctx.data_file.stem}"
+        waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    model_outputs_tsv = output_path / f"model-outputs-{ctx.data_file.stem}.txt"
+    unit_outputs_tsv = output_path / f"unit_output-{ctx.data_file.stem}.txt"
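+    # A unit output file is only needed for speech output; for text-only tasks
+    # a null context stands in so the same `with` statement can be used.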
+    with open(model_outputs_tsv, "w") as hyp_file, open(
+        unit_outputs_tsv, "w"
+    ) if ctx.output_modality == Modality.SPEECH else contextlib.nullcontext(
+        itertools.repeat(None)
+    ) as unit_file:
+        sample_id = 0
+        if ctx.output_modality == Modality.SPEECH:
+            hyp_file.write("ref_tgt_text\tpred_tgt_text\tpred_tgt_audio\n")
+        else:
+            hyp_file.write("ref_tgt_text\tpred_tgt_text\n")
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            if ctx.input_modality == Modality.SPEECH:
+                src = example["audio"]["data"]["fbank"]
+                # Skip corrupted audio tensors.
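+                # A sequence is treated as corrupted if any of its fbank
+                # frames contains a NaN value.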
+                valid_sequences = ~torch.any(
+                    torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+                )
+                if not valid_sequences.all():
+                    logger.warning(
+                        f"Sample IDs {sample_id} to {sample_id + ctx.batch_size} have some corrupted input."
+                    )
+                    src["seqs"] = src["seqs"][valid_sequences]
+                    src["seq_lens"] = src["seq_lens"][valid_sequences]
+            else:
+                src = example["src_text"]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                (text_output, speech_output,) = translator.predict(
+                    src,
+                    ctx.task,
+                    ctx.target_lang,
+                    src_lang=ctx.source_lang,
+                    text_generation_opts=ctx.text_generation_opts,
+                    unit_generation_opts=ctx.unit_generation_opts,
+                    unit_generation_ngram_filtering=ctx.unit_generation_ngram_filtering,
+                )
+            else:
+                text_output = []
+                if ctx.output_modality == Modality.SPEECH:
+                    speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+                else:
+                    speech_output = None
+
+            if valid_sequences is not None and not valid_sequences.all():
+                (text_output, speech_output,) = adjust_output_for_corrupted_inputs(
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps = [str(s) for s in text_output]
+            refs = [str(s) for s in example[ctx.ref_field]]
+
+            for i in range(len(text_output)):
+                if ctx.output_modality == Modality.SPEECH:
+                    assert speech_output is not None
+                    u = speech_output.units[i]
+                    str_units = [str(i) for i in u]
+                    unit_file.write(" ".join(str_units) + "\n")
+                    wav_fp = str(waveforms_dir / f"{sample_id}_pred.wav")
+                    torchaudio.save(
+                        wav_fp,
+                        speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                        sample_rate=speech_output.sample_rate,
+                    )
+                    hyp_file.write(f"{refs[i]}\t{hyps[i]}\t{wav_fp}\n")
+                else:
+                    hyp_file.write(f"{refs[i]}\t{hyps[i]}\n")
+
+                sample_id += 1
+                progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {sample_id} samples")
+
+    compute_quality_metrics(
+        output_manifest_tsv_path=model_outputs_tsv,
+        output_path=output_path,
+        tgt_lang=ctx.target_lang,
+        task=ctx.task,
+        device=ctx.device,
+        whisper_model_name=whisper_model_name,
+    )
+
+
+def main(optional_args: Optional[Dict[str, Any]] = None) -> None:
+    parser = argparse.ArgumentParser(
+        description="M4T evaluation for tasks supported by Translator."
+    )
+    parser.add_argument(
+        "--data_file", type=str, help="Data file (.tsv) to be evaluated."
+    )
+
+    parser = add_inference_arguments(parser)
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=str,
+        help="Root directory for the audio filenames in the data file.",
+        default="",
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default="tgt_text",
+    )
+    parser.add_argument(
+        "--whisper_model_name",
+        type=str,
+        help="Whisper model to be used for ASR-BLEU scoring",
+        default="large",
+    )
+    args, unknown = parser.parse_known_args()
+    default_args = vars(args)
+    if optional_args:
+        default_args.update(optional_args)
+    args = Namespace(**default_args)
+
+    if not args.data_file or not args.task or not args.tgt_lang:
+        raise Exception(
+            "Please provide required arguments for evaluation - data_file, task, tgt_lang"
+        )
+
+    if not Path(args.data_file).exists():
+        raise ValueError(f"Invalid data_file to be evaluated: {args.data_file}")
+
+    input_modality, output_modality = Translator.get_modalities_from_task_str(args.task)
+
+    if input_modality == Modality.SPEECH and not Path(args.audio_root_dir).exists():
+        raise ValueError(
+            f"Invalid audio_root_dir: {args.audio_root_dir} for speech input."
+        )
+
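+    # Use fp16 on GPU when available; otherwise fall back to fp32 on CPU.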
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    text_tokenizer = load_unity_text_tokenizer(args.model_name)
+
+    # TODO: Avoid loading the T2U model, vocoder when the output
+    # modality is text.
+    translator = Translator(
+        args.model_name,
+        args.vocoder_name,
+        device,
+        text_tokenizer=text_tokenizer,
+        dtype=dtype,
+        input_modality=input_modality,
+        output_modality=output_modality,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    # fmt: off
+    ctx = EvalContext(
+        task=args.task,
+        input_modality=input_modality,
+        output_modality=output_modality,
+        model_name=args.model_name,
+        data_file=Path(args.data_file),
+        audio_root_dir=Path(args.audio_root_dir),
+        target_lang=args.tgt_lang,
+        source_lang=args.src_lang,
+        batch_size=args.batch_size,
+        device=device,
+        dtype=dtype,
+        ref_field=args.ref_field,
+        text_generation_opts=text_generation_opts,
+        unit_generation_opts=unit_generation_opts,
+        unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+        output_path=args.output_path,
+    )
+    # fmt: on
+    logger.info(f"Running inference on {device=} with {dtype=}, {ctx.batch_size=}.")
+
+    run_eval(translator, text_tokenizer, ctx, args.whisper_model_name)
+
+
+if __name__ == "__main__":
+    main()
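+
+# Example invocation (a sketch; it assumes the `m4t_evaluate` console script
+# that seamless_communication registers for this module, and that flags such
+# as --tgt_lang and --output_path are added by add_inference_arguments):
+#   m4t_evaluate --data_file /path/to/test.tsv --task S2TT --tgt_lang fra \
+#       --audio_root_dir /path/to/audio --output_path /path/to/eval_output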