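"""Runs SeamlessExpressive inference over a TSV manifest of input audio,
writing translated text, discrete unit sequences, and predicted waveforms,
plus (optionally) a combined result TSV for downstream evaluation."""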
import argparse
import contextlib
import logging
from argparse import Namespace
from pathlib import Path
from typing import Optional

import pandas as pd
import torch
import torchaudio
from fairseq2.data import Collater, DataPipeline, FileMapper
from fairseq2.data.audio import (
    AudioDecoder,
    WaveformToFbankConverter,
    WaveformToFbankOutput,
)
from fairseq2.data.text import StrSplitter, read_text
from fairseq2.typing import DataType, Device
from sacrebleu.metrics import BLEU
from torch import Tensor
from tqdm import tqdm

from seamless_communication.cli.expressivity.evaluate.pretssel_inference_helper import (
    PretsselGenerator,
)
from seamless_communication.cli.m4t.evaluate.evaluate import (
    adjust_output_for_corrupted_inputs,
    count_lines,
)
from seamless_communication.cli.m4t.predict import (
    add_inference_arguments,
    set_generation_opts,
)
from seamless_communication.inference import BatchedSpeechOutput, Translator
from seamless_communication.models.unity import (
    load_gcmvn_stats,
    load_unity_unit_tokenizer,
)
from seamless_communication.store import add_gated_assets

logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)

logger = logging.getLogger(__name__)


def build_data_pipeline(
    args: Namespace,
    device: Device,
    dtype: DataType,
    gcmvn_mean: Tensor,
    gcmvn_std: Tensor,
) -> DataPipeline:
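    # Read the TSV header up front to validate that the audio column exists.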
    with open(args.data_file, "r") as f:
        header = f.readline().strip("\n").split("\t")
        assert (
            args.audio_field in header
        ), f"Input file does not contain the {args.audio_field} field"

    n_parallel = 4

    split_tsv = StrSplitter(names=header)

    pipeline_builder = read_text(args.data_file, rtrim=True).skip(1).map(split_tsv)

    assert args.audio_root_dir is not None
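
    # Resolve relative audio paths against --audio_root_dir, keeping up to
    # ten open file descriptors cached for reuse.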
    map_file = FileMapper(root_dir=args.audio_root_dir, cached_fd_count=10)

    pipeline_builder.map(
        map_file, selector=args.audio_field, num_parallel_calls=n_parallel
    )

    decode_audio = AudioDecoder(dtype=torch.float32, device=device)

    convert_to_fbank = WaveformToFbankConverter(
        num_mel_bins=80,
        waveform_scale=2**15,
        channel_last=True,
        standardize=False,
        device=device,
        dtype=dtype,
    )
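
    # The translator consumes per-utterance normalized fbanks, while the
    # prosody encoder expects features normalized with the vocoder's global
    # CMVN (gcmvn) statistics, so both variants are kept per example.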
    def normalize_fbank(data: WaveformToFbankOutput) -> WaveformToFbankOutput:
        fbank = data["fbank"]
        std, mean = torch.std_mean(fbank, dim=0)
        data["fbank"] = fbank.subtract(mean).divide(std)
        data["gcmvn_fbank"] = fbank.subtract(gcmvn_mean).divide(gcmvn_std)
        return data

    pipeline_builder.map(
        [decode_audio, convert_to_fbank, normalize_fbank],
        selector=f"{args.audio_field}.data",
        num_parallel_calls=n_parallel,
    )
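
    # Bucket rows into batches of --batch_size, pad-collate, and prefetch.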
    pipeline_builder.bucket(bucket_size=args.batch_size)

    collate = Collater(pad_value=0, pad_to_multiple=1)

    pipeline_builder.map(collate, num_parallel_calls=n_parallel)

    pipeline_builder.prefetch(4)

    return pipeline_builder.and_return()


def main() -> None:
    parser = argparse.ArgumentParser(description="Running SeamlessExpressive inference")
    parser.add_argument(
        "data_file", type=Path, help="Data file (.tsv) to be evaluated."
    )

    parser = add_inference_arguments(parser)
    parser.add_argument(
        "--gated-model-dir",
        type=Path,
        required=False,
        help="SeamlessExpressive model directory.",
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        help="Inference batch size.",
        default=4,
    )
    parser.add_argument(
        "--audio_root_dir",
        type=Path,
        help="Root directory for the audio filenames in the data file.",
        default="",
    )
    parser.add_argument(
        "--audio_field",
        type=str,
        help="Field that contains the input audio file paths.",
        default="src_audio",
    )
    parser.add_argument(
        "--ref_field",
        type=str,
        help="Reference target text field to compute the BLEU score against.",
        default=None,
    )
    parser.add_argument(
        "--duration_factor",
        type=float,
        help="The duration factor for the NAR T2U model.",
        default=1.0,
    )
    parser.add_argument(
        "--output_result_tsv",
        # argparse's type=bool treats any non-empty string (including
        # "False") as True, so parse boolean strings explicitly instead.
        type=lambda s: s.lower() not in ("false", "0", "no"),
        help="Whether to output results in TSV format (for full-blown evaluation).",
        default=True,
    )
    args = parser.parse_args()

    if args.gated_model_dir:
        add_gated_assets(args.gated_model_dir)
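
    # Use fp16 on GPU for faster inference; fall back to fp32 on CPU.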
    if torch.cuda.is_available():
        device = torch.device("cuda:0")
        dtype = torch.float16
    else:
        device = torch.device("cpu")
        dtype = torch.float32

    unit_tokenizer = load_unity_unit_tokenizer(args.model_name)

    # Global CMVN stats are looked up by vocoder name; the data pipeline uses
    # them to normalize the prosody-encoder input.
    _gcmvn_mean, _gcmvn_std = load_gcmvn_stats(args.vocoder_name)
    gcmvn_mean = torch.tensor(_gcmvn_mean, device=device, dtype=dtype)
    gcmvn_std = torch.tensor(_gcmvn_std, device=device, dtype=dtype)

    pipeline = build_data_pipeline(args, device, dtype, gcmvn_mean, gcmvn_std)
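
    # Speech is synthesized separately by PretsselGenerator below, so the
    # Translator is created without a vocoder.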
    translator = Translator(
        args.model_name,
        vocoder_name_or_card=None,
        device=device,
        dtype=dtype,
    )

    text_generation_opts, unit_generation_opts = set_generation_opts(args)

    logger.info(f"{text_generation_opts=}")
    logger.info(f"{unit_generation_opts=}")
    logger.info(
        f"unit_generation_ngram_filtering={args.unit_generation_ngram_filtering}"
    )

    pretssel_generator = PretsselGenerator(
        args.vocoder_name,
        vocab_info=unit_tokenizer.vocab_info,
        device=device,
        dtype=dtype,
    )

    total_steps = count_lines(args.data_file) - 1
    progress_bar = tqdm(total=total_steps)

    output_path = args.output_path / args.data_file.stem
    output_path.mkdir(parents=True, exist_ok=True)

    waveforms_dir = output_path / "waveform"
    waveforms_dir.mkdir(parents=True, exist_ok=True)

    hyps = []
    refs = []
    audio_hyps = []

    with contextlib.ExitStack() as stack:
        hyp_file = stack.enter_context(
            open(output_path / f"text_output-{args.data_file.stem}.txt", "w")
        )
        unit_file = stack.enter_context(
            open(output_path / f"unit_output-{args.data_file.stem}.txt", "w")
        )

        sample_id = 0
        for example in pipeline:
            valid_sequences: Optional[Tensor] = None
            src = example[args.audio_field]["data"]["fbank"]
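
            # Mask out sequences whose fbank features contain NaNs
            # (corrupted or undecodable audio).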
            valid_sequences = ~torch.any(
                torch.any(torch.isnan(src["seqs"]), dim=1), dim=1
            )
            if not valid_sequences.all():
                logger.warning(
                    f"Sample IDs {sample_id} to {sample_id + args.batch_size} have corrupted input."
                )
                src["seqs"] = src["seqs"][valid_sequences]
                src["seq_lens"] = src["seq_lens"][valid_sequences]

            if src["seqs"].numel() > 0:
                prosody_encoder_input = example[args.audio_field]["data"]["gcmvn_fbank"]
                text_output, unit_output = translator.predict(
                    src,
                    args.task,
                    args.tgt_lang,
                    src_lang=args.src_lang,
                    text_generation_opts=text_generation_opts,
                    unit_generation_opts=unit_generation_opts,
                    unit_generation_ngram_filtering=args.unit_generation_ngram_filtering,
                    duration_factor=args.duration_factor,
                    prosody_encoder_input=prosody_encoder_input,
                )

                assert unit_output is not None
                speech_output = pretssel_generator.predict(
                    unit_output.units,
                    tgt_lang=args.tgt_lang,
                    prosody_encoder_input=prosody_encoder_input,
                )
            else:
                text_output = []
                speech_output = BatchedSpeechOutput(units=[], audio_wavs=[])

            if valid_sequences is not None and not valid_sequences.all():
                # Re-insert placeholder outputs at the positions of dropped
                # inputs so outputs stay aligned with the input rows.
                text_output, speech_output = adjust_output_for_corrupted_inputs(
                    valid_sequences,
                    text_output,
                    speech_output,
                )

            hyps += [str(s) for s in text_output]
            if args.ref_field is not None and args.ref_field in example:
                refs += [str(s) for s in example[args.ref_field]]

            for i in range(len(text_output)):
                t = text_output[i]
                idx = str(example["id"][i])
                hyp_file.write(f"{t}\n")

                u = speech_output.units[i]
                str_units = [str(unit) for unit in u]
                unit_file.write(" ".join(str_units) + "\n")
                torchaudio.save(
                    waveforms_dir / f"{idx}_pred.wav",
                    speech_output.audio_wavs[i][0].to(torch.float32).cpu(),
                    sample_rate=speech_output.sample_rate,
                )
                audio_hyps.append((waveforms_dir / f"{idx}_pred.wav").as_posix())

                sample_id += 1
                progress_bar.update(1)

    progress_bar.close()
    logger.info(f"Processed {len(hyps)} hyps, {len(refs)} refs")

    if args.output_result_tsv:
        output_tsv_file = output_path / f"generate-{args.data_file.stem}.tsv"
        output_tsv = pd.read_csv(args.data_file, quoting=3, sep="\t")
        text_out = []
        with open(hyp_file.name) as file:
            for line in file:
                text_out.append(line.strip())

        unit_out = []
        with open(unit_file.name) as file:
            for line in file:
                unit_out.append(line.strip())

        output_tsv["hypo_audio"] = audio_hyps
        output_tsv["s2t_out"] = text_out
        output_tsv["orig_unit"] = unit_out
        output_tsv.to_csv(output_tsv_file, quoting=3, sep="\t", index=False)
        logger.info(f"Output results in {output_tsv_file}")


if __name__ == "__main__":
    main()