# Source: seamless_communication/cli/expressivity/evaluate/pretssel_inference.py
# (uploaded to the Hugging Face Hub with huggingface_hub by victan; commit 7933050, 10.5 kB).
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import argparse
import contextlib
import logging
from argparse import Namespace
from pathlib import Path
from typing import Optional
import pandas as pd
import torch
import torchaudio
from fairseq2.data import Collater, DataPipeline, FileMapper
from fairseq2.data.audio import (
AudioDecoder,
WaveformToFbankConverter,
WaveformToFbankOutput,
)
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.typing import DataType, Device
from sacrebleu.metrics import BLEU # type: ignore[attr-defined]
from torch import Tensor
from tqdm import tqdm
from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
PretsselGenerator,
)
from seamless_communication.cli.m4t.evaluate.evaluate import (
adjust_output_for_corrupted_inputs,
count_lines,
)
from seamless_communication.cli.m4t.predict import (
add_inference_arguments,
set_generation_opts,
)
from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.models.unity import (
load_gcmvn_stats,
load_unity_unit_tokenizer,
)
from seamless_communication.store import add_gated_assets
# Root-logger configuration applied at import time: INFO level with a
# timestamped "<time> <LEVEL> -- <logger name>: <message>" layout.
# (logging.basicConfig is a no-op if the root logger already has handlers.)
_LOG_FORMAT = "%(asctime)s %(levelname)s -- %(name)s: %(message)s"
logging.basicConfig(format=_LOG_FORMAT, level=logging.INFO)

# Module-scoped logger used throughout this script.
logger = logging.getLogger(__name__)
def build_data_pipeline(
    args: Namespace,
    device: Device,
    dtype: DataType,
    gcmvn_mean: Tensor,
    gcmvn_std: Tensor,
) -> DataPipeline:
    """Build the batched fbank data pipeline used for SeamlessExpressive inference.

    ``args.data_file`` is read as a tab-separated file whose first line is the
    header. Each following row is split into named fields; the
    ``args.audio_field`` column is resolved against ``args.audio_root_dir``,
    decoded to a float32 waveform and converted to 80-bin fbank features.
    Two feature variants are stored under ``example[args.audio_field]["data"]``:

    * ``"fbank"``: fbank standardized with mean/std computed over dim 0 of the
      fbank tensor (presumably the time axis given ``channel_last=True`` --
      TODO confirm).
    * ``"gcmvn_fbank"``: the same fbank normalized with the supplied global
      ``gcmvn_mean`` / ``gcmvn_std`` statistics.

    Examples are bucketed into groups of ``args.batch_size``, collated with
    zero padding, and up to 4 collated batches are prefetched.

    Args:
        args: Parsed CLI arguments; reads ``data_file``, ``audio_field``,
            ``audio_root_dir`` and ``batch_size``.
        device: Device on which audio decoding and fbank extraction run.
        dtype: Data type of the produced fbank features.
        gcmvn_mean: Global mean used for the ``"gcmvn_fbank"`` variant.
        gcmvn_std: Global std used for the ``"gcmvn_fbank"`` variant.

    Returns:
        The finished data pipeline.

    Raises:
        ValueError: If the header lacks ``args.audio_field`` or
            ``args.audio_root_dir`` is ``None``.
    """
    with open(args.data_file, "r", encoding="utf-8") as f:
        # rstrip("\r\n") also tolerates CRLF line endings, which would
        # otherwise leave a stray "\r" glued to the last column name.
        header = f.readline().rstrip("\r\n").split("\t")

    # Explicit exceptions instead of `assert`, which is stripped under `python -O`.
    if args.audio_field not in header:
        raise ValueError(f"Input file does not contain {args.audio_field} field")
    if args.audio_root_dir is None:
        raise ValueError("`audio_root_dir` must be set to resolve the audio file paths.")

    n_parallel = 4

    split_tsv = StrSplitter(names=header)

    # Skip the header line already consumed above, then split rows into fields.
    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)

    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)

    pipeline_builder.map(
        map_file, selector=args.audio_field, num_parallel_calls=n_parallel
    )

    decode_audio = AudioDecoder(dtype=torch.float32, device=device)

    # waveform_scale=2**15 presumably rescales float samples to the int16
    # range the fbank extractor expects -- TODO confirm against fairseq2 docs.
    convert_to_fbank = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=False,
        device=device,
        dtype=dtype,
    )

    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
        """Attach locally standardized and globally normalized fbank variants."""
        fbank = data["fbank"]
        std, mean = torch.std_mean(fbank, dim=0)
        data["fbank"] = fbank.subtract(mean).divide(std)
        # `fbank` still refers to the un-standardized tensor here.
        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
        return data

    pipeline_builder.map(
        [decode_audio, convert_to_fbank, normalize_fbank],
        selector=f"{args.audio_field}.data",
        num_parallel_calls=n_parallel,
    )

    pipeline_builder.bucket(bucket_size=args.batch_size)

    collate = Collater(pad_value=0, pad_to_multiple=1)

    pipeline_builder.map(collate, num_parallel_calls=n_parallel)

    pipeline_builder.prefetch(4)

    return pipeline_builder.and_return()
def _str_to_bool(value: str) -> bool:
    """Convert a command-line string such as ``"true"`` / ``"False"`` to a bool.

    ``type=bool`` is an argparse pitfall: ``bool("False")`` is ``True`` because
    every non-empty string is truthy, so a flag declared that way can never be
    switched off from the command line.

    Raises:
        argparse.ArgumentTypeError: If ``value`` is not a recognized boolean.
    """
    normalized = value.strip().lower()
    if normalized in ("1", "true", "t", "yes", "y"):
        return True
    if normalized in ("0", "false", "f", "no", "n"):
        return False
    raise argparse.ArgumentTypeError(f"Expected a boolean value, got {value!r}.")


def _parse_args() -> Namespace:
    """Build this script's command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
    parser.add_argument(
        "data_file", type=Path, help="Data file (.tsv) to be evaluated."
    )

    parser = add_inference_arguments(parser)
    parser.add_argument(
        "--gated-model-dir",
        type=Path,
        required=False,
        help="SeamlessExpressive model directory.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Inference batch size.",
        default=4,
    )
    parser.add_argument(
        "--audio_root_dir",
        type=Path,
        help="Root directory for the audio filenames in the data file.",
        default="",
    )
    parser.add_argument(
        "--audio_field",
        type=str,
        help="Field that includes the input audio file paths.",
        default="src_audio",
    )
    parser.add_argument(
        "--ref_field",
        type=str,
        help="Reference target text field to compute the BLEU score against.",
        default=None,
    )
    parser.add_argument(
        "--duration_factor",
        type=float,
        help="The duration factor for NAR T2U model.",
        default=1.0,
    )
    parser.add_argument(
        "--output_result_tsv",
        type=_str_to_bool,
        help="Whether to output results in tsv format (for full-blown evaluation)",
        default=True,
    )
    return parser.parse_args()


def _log_bleu(hyps: list, refs: list, tgt_lang: str) -> None:
    """Log corpus BLEU of ``hyps`` against ``refs`` when references were collected."""
    if not refs:
        return
    if len(refs) != len(hyps):
        logger.warning(
            f"Skipping BLEU: got {len(hyps)} hypotheses but {len(refs)} references."
        )
        return
    # Scripts without whitespace word boundaries are scored on characters.
    tokenize = "char" if tgt_lang in ("cmn", "jpn", "tha", "lao", "mya") else "13a"
    score = BLEU(tokenize=tokenize).corpus_score(hyps, [refs])
    logger.info(f"BLEU ({tokenize} tokenization): {score}")


def _write_result_tsv(
    data_file: Path,
    output_path: Path,
    audio_hyps: list,
    text_out: list,
    unit_out: list,
) -> Path:
    """Write the input TSV augmented with prediction columns; return its path.

    Raises:
        ValueError: If a prediction column's length differs from the number
            of data rows in ``data_file``.
    """
    output_tsv_file = output_path / f"generate-{data_file.stem}.tsv"
    output_tsv = pd.read_csv(data_file, quoting=3, sep="\t")

    columns = (
        ("hypo_audio", audio_hyps),
        ("s2t_out", text_out),
        ("orig_unit", unit_out),
    )
    for column_name, values in columns:
        if len(values) != len(output_tsv):
            raise ValueError(
                f"Column {column_name} has {len(values)} values but "
                f"{data_file} has {len(output_tsv)} rows."
            )
        output_tsv[column_name] = values

    output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
    return output_tsv_file


def main() -> None:
    """Run SeamlessExpressive inference (text/unit prediction + PRETSSEL vocoding).

    Reads the TSV given on the command line and writes, under
    ``<output_path>/<data_file stem>/``:

    * ``text_output-<stem>.txt``: one text hypothesis per line;
    * ``unit_output-<stem>.txt``: space-separated predicted units per line;
    * ``waveform/<id>_pred.wav``: the synthesized audio for each sample;
    * ``generate-<stem>.tsv``: the input TSV plus prediction columns, when
      ``--output_result_tsv`` is true.

    Runs on ``cuda:0`` in float16 when CUDA is available, else on CPU in float32.
    """
    args = _parse_args()

    if args.gated_model_dir:
        add_gated_assets(args.gated_model_dir)

    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)

    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)

    translator = Translator(
        args.model_name,
        vocoder_name_or_card=None,
        device=device,
        dtype=dtype,
    )

    text_generation_opts, unit_generation_opts = set_generation_opts(args)

    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    pretssel_generator = PretsselGenerator(
        args.vocoder_name,
        vocab_info=unit_tokenizer.vocab_info,
        device=device,
        dtype=dtype,
    )

    total_steps = count_lines(args.data_file) - 1  # minus the header line

    output_path = args.output_path / args.data_file.stem
    output_path.mkdir(parents=True, exist_ok=True)

    waveforms_dir = output_path / "waveform"
    waveforms_dir.mkdir(parents=True, exist_ok=True)

    hyps: list = []
    refs: list = []
    audio_hyps: list = []
    # In-memory copies of what is written to the text/unit files; used for the
    # result TSV so embedded newlines cannot desynchronize row counts.
    text_out: list = []
    unit_out: list = []

    with contextlib.ExitStack() as stack:
        progress_bar = stack.enter_context(tqdm(total=total_steps))
        hyp_file = stack.enter_context(
            open(
                output_path / f"text_output-{args.data_file.stem}.txt",
                "w",
                encoding="utf-8",
            )
        )
        unit_file = stack.enter_context(
            open(
                output_path / f"unit_output-{args.data_file.stem}.txt",
                "w",
                encoding="utf-8",
            )
        )

        sample_id = 0
        for example in pipeline:
            audio_data = example[args.audio_field]["data"]
            src = audio_data["fbank"]
            prosody_encoder_input = audio_data["gcmvn_fbank"]

            # A sequence counts as corrupted if any of its fbank values is NaN.
            valid_sequences = ~torch.any(
                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
            )
            has_corrupted = not bool(valid_sequences.all())
            if has_corrupted:
                logger.warning(
                    f"Sample IDs {sample_id} to "
                    f"{sample_id + len(valid_sequences) - 1} have some corrupted input."
                )
                # Drop corrupted rows from BOTH model inputs so their batch
                # dimensions stay aligned for translator and vocoder.
                for batch in (src, prosody_encoder_input):
                    batch["seqs"] = batch["seqs"][valid_sequences]
                    batch["seq_lens"] = batch["seq_lens"][valid_sequences]

            # Skip performing inference when the input is entirely corrupted.
            if src["seqs"].numel() > 0:
                text_output, unit_output = translator.predict(
                    src,
                    args.task,
                    args.tgt_lang,
                    src_lang=args.src_lang,
                    text_generation_opts=text_generation_opts,
                    unit_generation_opts=unit_generation_opts,
                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
                    duration_factor=args.duration_factor,
                    prosody_encoder_input=prosody_encoder_input,
                )
                assert unit_output is not None
                speech_output = pretssel_generator.predict(
                    unit_output.units,
                    tgt_lang=args.tgt_lang,
                    prosody_encoder_input=prosody_encoder_input,
                )
            else:
                text_output = []
                speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])

            if has_corrupted:
                text_output, speech_output = adjust_output_for_corrupted_inputs(  # type: ignore[assignment]
                    valid_sequences,
                    text_output,
                    speech_output,
                )

            hyps += [str(s) for s in text_output]
            if args.ref_field is not None and args.ref_field in example:
                refs += [str(s) for s in example[args.ref_field]]

            for i, text in enumerate(text_output):
                idx = str(example["id"][i])

                hyp_file.write(f"{text}\n")
                text_out.append(f"{text}".strip())

                unit_line = " ".join(str(unit) for unit in speech_output.units[i])
                unit_file.write(unit_line + "\n")
                unit_out.append(unit_line.strip())

                wav_path = waveforms_dir / f"{idx}_pred.wav"
                torchaudio.save(
                    wav_path,
                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                    sample_rate=speech_output.sample_rate,
                )
                audio_hyps.append(wav_path.as_posix())

                sample_id += 1
                progress_bar.update(1)

    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")
    _log_bleu(hyps, refs, args.tgt_lang)

    if args.output_result_tsv:
        output_tsv_file = _write_result_tsv(
            args.data_file, output_path, audio_hyps, text_out, unit_out
        )
        logger.info(f"Output results in {output_tsv_file}")
# Run the CLI only when this file is executed as a script, not when imported.
if __name__ == "__main__":
    main()