Upload seamless_communication/cli/expressivity/evaluate/pretssel_inference.py with huggingface_hub
Browse files
seamless_communication/cli/expressivity/evaluate/pretssel_inference.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# MIT_LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
import argparse
|
8 |
+
import contextlib
|
9 |
+
import logging
|
10 |
+
from argparse import Namespace
|
11 |
+
from pathlib import Path
|
12 |
+
from typing import Optional
|
13 |
+
|
14 |
+
import pandas as pd
|
15 |
+
import torch
|
16 |
+
import torchaudio
|
17 |
+
from fairseq2.data import Collater, DataPipeline, FileMapper
|
18 |
+
from fairseq2.data.audio import (
|
19 |
+
AudioDecoder,
|
20 |
+
WaveformToFbankConverter,
|
21 |
+
WaveformToFbankOutput,
|
22 |
+
)
|
23 |
+
from fairseq2.data.text import StrSplitter, read_text
|
24 |
+
from fairseq2.typing import DataType, Device
|
25 |
+
from sacrebleu.metrics import BLEU # type: ignore[attr-defined]
|
26 |
+
from torch import Tensor
|
27 |
+
from tqdm import tqdm
|
28 |
+
|
29 |
+
from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
|
30 |
+
PretsselGenerator,
|
31 |
+
)
|
32 |
+
from seamless_communication.cli.m4t.evaluate.evaluate import (
|
33 |
+
adjust_output_for_corrupted_inputs,
|
34 |
+
count_lines,
|
35 |
+
)
|
36 |
+
from seamless_communication.cli.m4t.predict import (
|
37 |
+
add_inference_arguments,
|
38 |
+
set_generation_opts,
|
39 |
+
)
|
40 |
+
from seamless_communication.inference import BatchedSpeechOutput, Translator
|
41 |
+
from seamless_communication.models.unity import (
|
42 |
+
load_gcmvn_stats,
|
43 |
+
load_unity_unit_tokenizer,
|
44 |
+
)
|
45 |
+
from seamless_communication.store import add_gated_assets
|
46 |
+
|
47 |
+
logging.basicConfig(
|
48 |
+
level=logging.INFO,
|
49 |
+
format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
|
50 |
+
)
|
51 |
+
|
52 |
+
logger = logging.getLogger(__name__)
|
53 |
+
|
54 |
+
|
55 |
+
def build_data_pipeline(
|
56 |
+
args: Namespace,
|
57 |
+
device: Device,
|
58 |
+
dtype: DataType,
|
59 |
+
gcmvn_mean: Tensor,
|
60 |
+
gcmvn_std: Tensor,
|
61 |
+
) -> DataPipeline:
|
62 |
+
with open(args.data_file, "r") as f:
|
63 |
+
header = f.readline().strip("\n").split("\t")
|
64 |
+
assert (
|
65 |
+
args.audio_field in header
|
66 |
+
), f"Input file does not contain {args.audio_field} field"
|
67 |
+
|
68 |
+
n_parallel = 4
|
69 |
+
|
70 |
+
split_tsv = StrSplitter(names=header)
|
71 |
+
|
72 |
+
pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)
|
73 |
+
|
74 |
+
assert args.audio_root_dir is not None
|
75 |
+
|
76 |
+
map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)
|
77 |
+
|
78 |
+
pipeline_builder.map(
|
79 |
+
map_file, selector=args.audio_field, num_parallel_calls=n_parallel
|
80 |
+
)
|
81 |
+
|
82 |
+
decode_audio = AudioDecoder(dtype=torch.float32, device=device)
|
83 |
+
|
84 |
+
convert_to_fbank = WaveformToFbankConverter(
|
85 |
+
num_mel_bins=80,
|
86 |
+
waveform_scale=2**15,
|
87 |
+
channel_last=True,
|
88 |
+
standardize=False,
|
89 |
+
device=device,
|
90 |
+
dtype=dtype,
|
91 |
+
)
|
92 |
+
|
93 |
+
def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
|
94 |
+
fbank = data["fbank"]
|
95 |
+
std, mean = torch.std_mean(fbank, dim=0)
|
96 |
+
data["fbank"] = fbank.subtract(mean).divide(std)
|
97 |
+
data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
|
98 |
+
return data
|
99 |
+
|
100 |
+
pipeline_builder.map(
|
101 |
+
[decode_audio, convert_to_fbank, normalize_fbank],
|
102 |
+
selector=f"{args.audio_field}.data",
|
103 |
+
num_parallel_calls=n_parallel,
|
104 |
+
)
|
105 |
+
|
106 |
+
pipeline_builder.bucket(bucket_size=args.batch_size)
|
107 |
+
|
108 |
+
collate = Collater(pad_value=0, pad_to_multiple=1)
|
109 |
+
|
110 |
+
pipeline_builder.map(collate, num_parallel_calls=n_parallel)
|
111 |
+
|
112 |
+
pipeline_builder.prefetch(4)
|
113 |
+
|
114 |
+
return pipeline_builder.and_return()
|
115 |
+
|
116 |
+
|
117 |
+
def main() -> None:
|
118 |
+
parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
|
119 |
+
parser.add_argument(
|
120 |
+
"data_file", type=Path, help="Data file (.tsv) to be evaluated."
|
121 |
+
)
|
122 |
+
|
123 |
+
parser = add_inference_arguments(parser)
|
124 |
+
param = parser.add_argument(
|
125 |
+
"--gated-model-dir",
|
126 |
+
type=Path,
|
127 |
+
required=False,
|
128 |
+
help="SeamlessExpressive model directory.",
|
129 |
+
)
|
130 |
+
parser.add_argument(
|
131 |
+
"--batch_size",
|
132 |
+
type=int,
|
133 |
+
help="Inference batch size.",
|
134 |
+
default=4,
|
135 |
+
)
|
136 |
+
parser.add_argument(
|
137 |
+
"--audio_root_dir",
|
138 |
+
type=Path,
|
139 |
+
help="Root directory for the audio filenames in the data file.",
|
140 |
+
default="",
|
141 |
+
)
|
142 |
+
parser.add_argument(
|
143 |
+
"--audio_field",
|
144 |
+
type=str,
|
145 |
+
help="Field that includes the input audio file paths.",
|
146 |
+
default="src_audio",
|
147 |
+
)
|
148 |
+
parser.add_argument(
|
149 |
+
"--ref_field",
|
150 |
+
type=str,
|
151 |
+
help="Reference target text field to compute the BLEU score against.",
|
152 |
+
default=None,
|
153 |
+
)
|
154 |
+
parser.add_argument(
|
155 |
+
"--duration_factor",
|
156 |
+
type=float,
|
157 |
+
help="The duration factor for NAR T2U model.",
|
158 |
+
default=1.0,
|
159 |
+
)
|
160 |
+
parser.add_argument(
|
161 |
+
"--output_result_tsv",
|
162 |
+
type=bool,
|
163 |
+
help="Whether to output results in tsv format (for full-blown evaluation)",
|
164 |
+
default=True,
|
165 |
+
)
|
166 |
+
args = parser.parse_args()
|
167 |
+
|
168 |
+
if args.gated_model_dir:
|
169 |
+
add_gated_assets(args.gated_model_dir)
|
170 |
+
|
171 |
+
if torch.cuda.is_available():
|
172 |
+
device = torch.device("cuda:0")
|
173 |
+
dtype = torch.float16
|
174 |
+
else:
|
175 |
+
device = torch.device("cpu")
|
176 |
+
dtype = torch.float32
|
177 |
+
|
178 |
+
unit_tokenizer = load_unity_unit_tokenizer(args.model_name)
|
179 |
+
|
180 |
+
_gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
|
181 |
+
gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
|
182 |
+
gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)
|
183 |
+
|
184 |
+
pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
|
185 |
+
|
186 |
+
translator = Translator(
|
187 |
+
args.model_name,
|
188 |
+
vocoder_name_or_card=None,
|
189 |
+
device=device,
|
190 |
+
dtype=dtype,
|
191 |
+
)
|
192 |
+
|
193 |
+
text_generation_opts, unit_generation_opts = set_generation_opts(args)
|
194 |
+
|
195 |
+
logger.info(f"{text_generation_opts=}")
|
196 |
+
logger.info(f"{unit_generation_opts=}")
|
197 |
+
logger.info(
|
198 |
+
f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
|
199 |
+
)
|
200 |
+
|
201 |
+
pretssel_generator = PretsselGenerator(
|
202 |
+
args.vocoder_name,
|
203 |
+
vocab_info=unit_tokenizer.vocab_info,
|
204 |
+
device=device,
|
205 |
+
dtype=dtype,
|
206 |
+
)
|
207 |
+
|
208 |
+
total_steps = count_lines(args.data_file) - 1
|
209 |
+
progress_bar = tqdm(total=total_steps)
|
210 |
+
|
211 |
+
output_path = args.output_path / args.data_file.stem
|
212 |
+
output_path.mkdir(parents=True, exist_ok=True)
|
213 |
+
|
214 |
+
waveforms_dir = output_path / "waveform"
|
215 |
+
waveforms_dir.mkdir(parents=True, exist_ok=True)
|
216 |
+
|
217 |
+
hyps = []
|
218 |
+
refs = []
|
219 |
+
audio_hyps = []
|
220 |
+
|
221 |
+
with contextlib.ExitStack() as stack:
|
222 |
+
hyp_file = stack.enter_context(
|
223 |
+
open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
|
224 |
+
)
|
225 |
+
unit_file = stack.enter_context(
|
226 |
+
open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
|
227 |
+
)
|
228 |
+
|
229 |
+
sample_id = 0
|
230 |
+
for example in pipeline:
|
231 |
+
valid_sequences: Optional[Tensor] = None
|
232 |
+
src = example[args.audio_field]["data"]["fbank"]
|
233 |
+
# Skip corrupted audio tensors.
|
234 |
+
valid_sequences = ~torch.any(
|
235 |
+
torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
|
236 |
+
)
|
237 |
+
if not valid_sequences.all():
|
238 |
+
logger.warning(
|
239 |
+
f"Sample IDs {sample_id} to {sample_id + args.batch_size} has some corrupted input."
|
240 |
+
)
|
241 |
+
src["seqs"] = src["seqs"][valid_sequences]
|
242 |
+
src["seq_lens"] = src["seq_lens"][valid_sequences]
|
243 |
+
|
244 |
+
# Skip performing inference when the input is entirely corrupted.
|
245 |
+
if src["seqs"].numel() > 0:
|
246 |
+
prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
|
247 |
+
text_output, unit_output = translator.predict(
|
248 |
+
src,
|
249 |
+
args.task,
|
250 |
+
args.tgt_lang,
|
251 |
+
src_lang=args.src_lang,
|
252 |
+
text_generation_opts=text_generation_opts,
|
253 |
+
unit_generation_opts=unit_generation_opts,
|
254 |
+
unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
|
255 |
+
duration_factor=args.duration_factor,
|
256 |
+
prosody_encoder_input=prosody_encoder_input,
|
257 |
+
)
|
258 |
+
|
259 |
+
assert unit_output is not None
|
260 |
+
speech_output = pretssel_generator.predict(
|
261 |
+
unit_output.units,
|
262 |
+
tgt_lang=args.tgt_lang,
|
263 |
+
prosody_encoder_input=prosody_encoder_input,
|
264 |
+
)
|
265 |
+
|
266 |
+
else:
|
267 |
+
text_output = []
|
268 |
+
speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])
|
269 |
+
|
270 |
+
if valid_sequences is not None and not valid_sequences.all():
|
271 |
+
text_output, speech_output = adjust_output_for_corrupted_inputs( # type: ignore[assignment]
|
272 |
+
valid_sequences,
|
273 |
+
text_output,
|
274 |
+
speech_output,
|
275 |
+
)
|
276 |
+
|
277 |
+
hyps += [str(s) for s in text_output]
|
278 |
+
if args.ref_field is not None and args.ref_field in example:
|
279 |
+
refs += [str(s) for s in example[args.ref_field]]
|
280 |
+
|
281 |
+
for i in range(len(text_output)):
|
282 |
+
t = text_output[i]
|
283 |
+
idx = str(example["id"][i])
|
284 |
+
hyp_file.write(f"{t}\n")
|
285 |
+
|
286 |
+
u = speech_output.units[i]
|
287 |
+
str_units = [str(i) for i in u]
|
288 |
+
unit_file.write(" ".join(str_units) + "\n")
|
289 |
+
torchaudio.save(
|
290 |
+
waveforms_dir / f"{idx}_pred.wav",
|
291 |
+
speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
|
292 |
+
sample_rate=speech_output.sample_rate,
|
293 |
+
)
|
294 |
+
audio_hyps.append((waveforms_dir / f"{idx}_pred.wav").as_posix())
|
295 |
+
|
296 |
+
sample_id += 1
|
297 |
+
progress_bar.update(1)
|
298 |
+
|
299 |
+
progress_bar.close()
|
300 |
+
logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
|
301 |
+
|
302 |
+
if args.output_result_tsv:
|
303 |
+
output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
|
304 |
+
output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")
|
305 |
+
text_out = []
|
306 |
+
with open(hyp_file.name) as file:
|
307 |
+
for line in file:
|
308 |
+
text_out.append(line.strip())
|
309 |
+
|
310 |
+
unit_out = []
|
311 |
+
with open(unit_file.name) as file:
|
312 |
+
for line in file:
|
313 |
+
unit_out.append(line.strip())
|
314 |
+
|
315 |
+
output_tsv["hypo_audio"] = audio_hyps
|
316 |
+
output_tsv["s2t_out"] = text_out
|
317 |
+
output_tsv["orig_unit"] = unit_out
|
318 |
+
output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
|
319 |
+
logger.info(f"Output results in {output_tsv_file}")
|
320 |
+
|
321 |
+
|
322 |
+
if __name__ == "__main__":
|
323 |
+
main()
|