victan committed on
Commit 7933050
1 Parent(s): 89c9a17

Upload seamless_communication/cli/expressivity/evaluate/pretssel_inference.py with huggingface_hub
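For context, this is the stock commit message that huggingface_hub generates when a file is pushed through its upload API. A minimal sketch of such an upload is shown below; the local path, repo_id, and token handling are placeholders and assumptions, not details taken from this commit.

# Hypothetical upload call; the local path and repo_id are placeholders.
from huggingface_hub import HfApi

api = HfApi()  # assumes a cached access token with write permission
api.upload_file(
    path_or_fileobj="pretssel_inference.py",  # local copy of the script
    path_in_repo="seamless_communication/cli/expressivity/evaluate/pretssel_inference.py",
    repo_id="victan/your-repo",  # placeholder; the target repo is not named in this commit
)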

seamless_communication/cli/expressivity/evaluate/pretssel_inference.py ADDED
@@ -0,0 +1,323 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# MIT_LICENSE file in the root directory of this source tree.
+
+import argparse
+import contextlib
+import logging
+from argparse import Namespace
+from pathlib import Path
+from typing import Optional
+
+import pandas as pd
+import torch
+import torchaudio
+from fairseq2.data import Collater, DataPipeline, FileMapper
+from fairseq2.data.audio import (
+    AudioDecoder,
+    WaveformToFbankConverter,
+    WaveformToFbankOutput,
+)
+from fairseq2.data.text import StrSplitter, read_text
+from fairseq2.typing import DataType, Device
+from sacrebleu.metrics import BLEU  # type: ignore[attr-defined]
+from torch import Tensor
+from tqdm import tqdm
+
+from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
+    PretsselGenerator,
+)
+from seamless_communication.cli.m4t.evaluate.evaluate import (
+    adjust_output_for_corrupted_inputs,
+    count_lines,
+)
+from seamless_communication.cli.m4t.predict import (
+    add_inference_arguments,
+    set_generation_opts,
+)
+from seamless_communication.inference import BatchedSpeechOutput, Translator
+from seamless_communication.models.unity import (
+    load_gcmvn_stats,
+    load_unity_unit_tokenizer,
+)
+from seamless_communication.store import add_gated_assets
+
+logging.basicConfig(
+    level=logging.INFO,
+    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
+)
+
+logger = logging.getLogger(__name__)
+
+
+def build_data_pipeline(
+    args: Namespace,
+    device: Device,
+    dtype: DataType,
+    gcmvn_mean: Tensor,
+    gcmvn_std: Tensor,
+) -> DataPipeline:
+    with open(args.data_file, "r") as f:
+        header = f.readline().strip("\n").split("\t")
+        assert (
+            args.audio_field in header
+        ), f"Input file does not contain {args.audio_field} field"
+
+    n_parallel = 4
+
+    split_tsv = StrSplitter(names=header)
+
+    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)
+
+    assert args.audio_root_dir is not None
+
+    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
+
+    pipeline_builder.map(
+        map_file, selector=args.audio_field, num_parallel_calls=n_parallel
+    )
+
+    decode_audio = AudioDecoder(dtype=torch.float32, device=device)
+
+    convert_to_fbank = WaveformToFbankConverter(
+        num_mel_bins=80,
+        waveform_scale=2**15,
+        channel_last=True,
+        standardize=False,
+        device=device,
+        dtype=dtype,
+    )
+
+    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
+        fbank = data["fbank"]
+        std, mean = torch.std_mean(fbank, dim=0)
+        data["fbank"] = fbank.subtract(mean).divide(std)
+        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
+        return data
+
+    pipeline_builder.map(
+        [decode_audio, convert_to_fbank, normalize_fbank],
+        selector=f"{args.audio_field}.data",
+        num_parallel_calls=n_parallel,
+    )
+
+    pipeline_builder.bucket(bucket_size=args.batch_size)
+
+    collate = Collater(pad_value=0, pad_to_multiple=1)
+
+    pipeline_builder.map(collate, num_parallel_calls=n_parallel)
+
+    pipeline_builder.prefetch(4)
+
+    return pipeline_builder.and_return()
+
+
+def main() -> None:
+    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
+    parser.add_argument(
+        "data_file", type=Path, help="Data file (.tsv) to be evaluated."
+    )
+
+    parser = add_inference_arguments(parser)
+    param = parser.add_argument(
+        "--gated-model-dir",
+        type=Path,
+        required=False,
+        help="SeamlessExpressive model directory.",
+    )
+    parser.add_argument(
+        "--batch_size",
+        type=int,
+        help="Inference batch size.",
+        default=4,
+    )
+    parser.add_argument(
+        "--audio_root_dir",
+        type=Path,
+        help="Root directory for the audio filenames in the data file.",
+        default="",
+    )
+    parser.add_argument(
+        "--audio_field",
+        type=str,
+        help="Field that includes the input audio file paths.",
+        default="src_audio",
+    )
+    parser.add_argument(
+        "--ref_field",
+        type=str,
+        help="Reference target text field to compute the BLEU score against.",
+        default=None,
+    )
+    parser.add_argument(
+        "--duration_factor",
+        type=float,
+        help="The duration factor for NAR T2U model.",
+        default=1.0,
+    )
+    parser.add_argument(
+        "--output_result_tsv",
+        type=bool,
+        help="Whether to output results in tsv format (for full-blown evaluation)",
+        default=True,
+    )
+    args = parser.parse_args()
+
+    if args.gated_model_dir:
+        add_gated_assets(args.gated_model_dir)
+
+    if torch.cuda.is_available():
+        device = torch.device("cuda:0")
+        dtype = torch.float16
+    else:
+        device = torch.device("cpu")
+        dtype = torch.float32
+
+    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
+
+    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
+    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
+    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
+
+    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
+
+    translator = Translator(
+        args.model_name,
+        vocoder_name_or_card=None,
+        device=device,
+        dtype=dtype,
+    )
+
+    text_generation_opts, unit_generation_opts = set_generation_opts(args)
+
+    logger.info(f"{text_generation_opts=}")
+    logger.info(f"{unit_generation_opts=}")
+    logger.info(
+        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
+    )
+
+    pretssel_generator = PretsselGenerator(
+        args.vocoder_name,
+        vocab_info=unit_tokenizer.vocab_info,
+        device=device,
+        dtype=dtype,
+    )
+
+    total_steps = count_lines(args.data_file) - 1
+    progress_bar = tqdm(total=total_steps)
+
+    output_path = args.output_path / args.data_file.stem
+    output_path.mkdir(parents=True, exist_ok=True)
+
+    waveforms_dir = output_path / "waveform"
+    waveforms_dir.mkdir(parents=True, exist_ok=True)
+
+    hyps = []
+    refs = []
+    audio_hyps = []
+
+    with contextlib.ExitStack() as stack:
+        hyp_file = stack.enter_context(
+            open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
+        )
+        unit_file = stack.enter_context(
+            open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
+        )
+
+        sample_id = 0
+        for example in pipeline:
+            valid_sequences: Optional[Tensor] = None
+            src = example[args.audio_field]["data"]["fbank"]
+            # Skip corrupted audio tensors.
+            valid_sequences = ~torch.any(
+                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
+            )
+            if not valid_sequences.all():
+                logger.warning(
+                    f"Sample IDs {sample_id} to {sample_id + args.batch_size} has some corrupted input."
+                )
+                src["seqs"] = src["seqs"][valid_sequences]
+                src["seq_lens"] = src["seq_lens"][valid_sequences]
+
+            # Skip performing inference when the input is entirely corrupted.
+            if src["seqs"].numel() > 0:
+                prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
+                text_output, unit_output = translator.predict(
+                    src,
+                    args.task,
+                    args.tgt_lang,
+                    src_lang=args.src_lang,
+                    text_generation_opts=text_generation_opts,
+                    unit_generation_opts=unit_generation_opts,
+                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
+                    duration_factor=args.duration_factor,
+                    prosody_encoder_input=prosody_encoder_input,
+                )
+
+                assert unit_output is not None
+                speech_output = pretssel_generator.predict(
+                    unit_output.units,
+                    tgt_lang=args.tgt_lang,
+                    prosody_encoder_input=prosody_encoder_input,
+                )
+
+            else:
+                text_output = []
+                speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
+
+            if valid_sequences is not None and not valid_sequences.all():
+                text_output, speech_output = adjust_output_for_corrupted_inputs(  # type: ignore[assignment]
+                    valid_sequences,
+                    text_output,
+                    speech_output,
+                )
+
+            hyps += [str(s) for s in text_output]
+            if args.ref_field is not None and args.ref_field in example:
+                refs += [str(s) for s in example[args.ref_field]]
+
+            for i in range(len(text_output)):
+                t = text_output[i]
+                idx = str(example["id"][i])
+                hyp_file.write(f"{t}\n")
+
+                u = speech_output.units[i]
+                str_units = [str(i) for i in u]
+                unit_file.write(" ".join(str_units) + "\n")
+                torchaudio.save(
+                    waveforms_dir / f"{idx}_pred.wav",
+                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
+                    sample_rate=speech_output.sample_rate,
+                )
+                audio_hyps.append((waveforms_dir / f"{idx}_pred.wav").as_posix())
+
+            sample_id += 1
+            progress_bar.update(1)
+
+    progress_bar.close()
+    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
+
+    if args.output_result_tsv:
+        output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
+        output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")
+        text_out = []
+        with open(hyp_file.name) as file:
+            for line in file:
+                text_out.append(line.strip())
+
+        unit_out = []
+        with open(unit_file.name) as file:
+            for line in file:
+                unit_out.append(line.strip())
+
+        output_tsv["hypo_audio"] = audio_hyps
+        output_tsv["s2t_out"] = text_out
+        output_tsv["orig_unit"] = unit_out
+        output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
+        logger.info(f"Output results in {output_tsv_file}")
+
+
+if __name__ == "__main__":
+    main()
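As a usage note grounded in the script above: the input TSV needs a header row, a column named by --audio_field (default src_audio) whose paths are resolved against --audio_root_dir, an id column used to name the generated waveforms, and optionally a --ref_field column holding reference text. Below is a minimal sketch of building such a file; the file names, the tgt_text column name, and the text values are invented for illustration.

# Sketch of the expected input layout; paths, column name "tgt_text", and text are illustrative.
import csv

rows = [
    {"id": "utt_0001", "src_audio": "clips/utt_0001.wav", "tgt_text": "Bonjour."},
    {"id": "utt_0002", "src_audio": "clips/utt_0002.wav", "tgt_text": "Merci."},
]

with open("eval_set.tsv", "w", newline="") as f:
    writer = csv.DictWriter(f, fieldnames=["id", "src_audio", "tgt_text"], delimiter="\t")
    writer.writeheader()  # build_data_pipeline reads this header line to locate the audio field
    writer.writerows(rows)

Per the script, results are written under output_path/<data_file stem>/: text_output-*.txt, unit_output-*.txt, per-sample waveforms in waveform/, and, when --output_result_tsv is enabled, generate-*.tsv with hypo_audio, s2t_out, and orig_unit columns appended to the input rows.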