#!/usr/bin/env python3
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
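#
# Example invocation (values are illustrative, not prescribed by this script):
#   python examples/speech_to_text/prep_mtedx_data.py \
#     --data-root ${MTEDX_ROOT} --task asr --vocab-type unigram --vocab-size 1000
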
import argparse
import logging
import os
from pathlib import Path
import shutil
from itertools import groupby
from tempfile import NamedTemporaryFile
from typing import Tuple

import pandas as pd
import soundfile as sf
from examples.speech_to_text.data_utils import (
    create_zip,
    extract_fbank_features,
    filter_manifest_df,
    gen_config_yaml,
    gen_vocab,
    get_zip_manifest,
    load_df_from_tsv,
    save_df_to_tsv,
)
import torch
from torch.utils.data import Dataset
from tqdm import tqdm

from fairseq.data.audio.audio_utils import get_waveform, convert_waveform


log = logging.getLogger(__name__)


MANIFEST_COLUMNS = [
    "id", "audio", "n_frames", "tgt_text", "speaker", "tgt_lang"
]


class mTEDx(Dataset):
    """
    Create a Dataset for Multilingual TEDx.

    Each item is a tuple of the form: waveform, sample_rate, source utterance,
    target utterance, speaker_id, target language, utterance_id
    """

    SPLITS = ["train", "valid", "test"]
    LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar",
                 "de-de", "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es",
                 "fr-pt", "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"]

    def __init__(self, root: str, lang: str, split: str) -> None:
        assert split in self.SPLITS and lang in self.LANGPAIRS
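        # Expected on-disk layout (checked by the asserts below):
        #   <root>/<lang-pair>/data/<split>/{wav,txt}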
        _root = Path(root) / f"{lang}" / "data" / split
        wav_root, txt_root = _root / "wav", _root / "txt"
        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir()
        # Load audio segments
        try:
            import yaml
        except ImportError:
            print("Please install PyYAML to load the Multilingual TEDx YAML files")
            raise
        with open(txt_root / f"{split}.yaml") as f:
            segments = yaml.load(f, Loader=yaml.BaseLoader)
        # Load source and target utterances
        src, tgt = lang.split("-")
        for _lang in [src, tgt]:
            with open(txt_root / f"{split}.{_lang}") as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances)
            for i, u in enumerate(utterances):
                segments[i][_lang] = u
        # Gather info
        self.data = []
        for wav_filename, _seg_group in groupby(segments, lambda x: x["wav"]):
            wav_filename = wav_filename.replace(".wav", ".flac")
            wav_path = wav_root / wav_filename
            sample_rate = sf.info(wav_path.as_posix()).samplerate
            seg_group = sorted(_seg_group, key=lambda x: float(x["offset"]))
            for i, segment in enumerate(seg_group):
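                # The YAML gives offsets and durations in seconds; convert
                # them to sample counts for get_waveform().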
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{wav_path.stem}_{i}"
                self.data.append(
                    (
                        wav_path.as_posix(),
                        offset,
                        n_frames,
                        sample_rate,
                        segment[src],
                        segment[tgt],
                        segment["speaker_id"],
                        tgt,
                        _id,
                    )
                )

    def __getitem__(
        self, n: int
    ) -> Tuple[torch.Tensor, int, str, str, str, str, str]:
        wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang, \
            utt_id = self.data[n]
        waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
        waveform = torch.from_numpy(waveform)
        return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id

    def __len__(self) -> int:
        return len(self.data)


def process(args):
    root = Path(args.data_root).absolute()
    for lang in mTEDx.LANGPAIRS:
        cur_root = root / f"{lang}"
        if not cur_root.is_dir():
            print(f"{cur_root.as_posix()} does not exist. Skipped.")
            continue
        # Extract features
        audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
        audio_root.mkdir(exist_ok=True)
        for split in mTEDx.SPLITS:
            print(f"Fetching split {split}...")
            dataset = mTEDx(root.as_posix(), lang, split)
            if args.use_audio_input:
                print("Converting audios...")
                for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
                    tgt_sample_rate = 16_000
                    _wavform, _ = convert_waveform(
                        waveform, sample_rate, to_mono=True,
                        to_sample_rate=tgt_sample_rate
                    )
                    sf.write(
                        (audio_root / f"{utt_id}.flac").as_posix(),
                        _wavform.T.numpy(), tgt_sample_rate
                    )
            else:
                print("Extracting log mel filter bank features...")
                for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
                    extract_fbank_features(
                        waveform, sample_rate, audio_root / f"{utt_id}.npy"
                    )
        # Pack features into ZIP
        zip_path = cur_root / f"{audio_root.name}.zip"
        print("ZIPing audios/features...")
        create_zip(audio_root, zip_path)
        print("Fetching ZIP manifest...")
        audio_paths, audio_lengths = get_zip_manifest(zip_path)
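        # audio_paths/audio_lengths map each utterance id to its location
        # inside the ZIP and its length; the per-split TSV manifests below
        # reference these entries directly.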
        # Generate TSV manifest
        print("Generating manifest...")
        train_text = []
        for split in mTEDx.SPLITS:
            is_train_split = split.startswith("train")
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            ds = mTEDx(args.data_root, lang, split)
            for _, _, src_utt, tgt_utt, spk_id, tgt_lang, utt_id in tqdm(ds):
                manifest["id"].append(utt_id)
                manifest["audio"].append(audio_paths[utt_id])
                manifest["n_frames"].append(audio_lengths[utt_id])
                manifest["tgt_text"].append(
                    src_utt if args.task == "asr" else tgt_utt
                )
                manifest["speaker"].append(spk_id)
                manifest["tgt_lang"].append(tgt_lang)
            if is_train_split:
                train_text.extend(manifest["tgt_text"])
            df = pd.DataFrame.from_dict(manifest)
            df = filter_manifest_df(df, is_train_split=is_train_split)
            save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")
        # Generate vocab
        v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
        spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
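        # The training-split target text is written to a temporary file from
        # which gen_vocab trains a SentencePiece model and dictionary.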
        with NamedTemporaryFile(mode="w") as f:
            for t in train_text:
                f.write(t + "\n")
            f.flush()  # ensure all buffered text is on disk before training
            gen_vocab(
                Path(f.name),
                cur_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
            )
        # Generate config YAML
        if args.use_audio_input:
            gen_config_yaml(
                cur_root,
                spm_filename=spm_filename_prefix + ".model",
                yaml_filename=f"config_{args.task}.yaml",
                specaugment_policy=None,
                extra={"use_audio_input": True}
            )
        else:
            gen_config_yaml(
                cur_root,
                spm_filename=spm_filename_prefix + ".model",
                yaml_filename=f"config_{args.task}.yaml",
                specaugment_policy="lb",
            )
        # Clean up
        shutil.rmtree(audio_root)


def process_joint(args):
    cur_root = Path(args.data_root)
    assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \
        "downloaded data is not available for all language pairs"
    # Generate vocab
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
    with NamedTemporaryFile(mode="w") as f:
        for lang in mTEDx.LANGPAIRS:
            tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv"
            df = load_df_from_tsv(tsv_path)
            for t in df["tgt_text"]:
                f.write(t + "\n")
        f.flush()  # ensure all buffered text is on disk before training
        special_symbols = None
        if args.joint:
            # Add tgt_lang tags to dict
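            # (e.g. <lang:en>); these pair with prepend_tgt_lang_tag below so
            # a single joint model can be told which language to produce.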
            special_symbols = list(
                {f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS}
            )
        gen_vocab(
            Path(f.name),
            cur_root / spm_filename_prefix,
            args.vocab_type,
            args.vocab_size,
            special_symbols=special_symbols
        )
    # Generate config YAML
    gen_config_yaml(
        cur_root,
        spm_filename=spm_filename_prefix + ".model",
        yaml_filename=f"config_{args.task}.yaml",
        specaugment_policy="ld",
        prepend_tgt_lang_tag=args.joint,
    )
    # Make symbolic links to manifests
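    # (one top-level <split>_<lang-pair>_<task>.tsv per language pair, so the
    # joint setup can read every manifest from the data root).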
    for lang in mTEDx.LANGPAIRS:
        for split in mTEDx.SPLITS:
            src_path = cur_root / f"{lang}" / f"{split}_{args.task}.tsv"
            desc_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
            if not desc_path.is_symlink():
                os.symlink(src_path, desc_path)


def main():
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=8000, type=int)
    parser.add_argument("--task", type=str, choices=["asr", "st"])
    parser.add_argument(
        "--joint", action="store_true",
        help="build a single joint vocabulary and config across all language pairs"
    )
    parser.add_argument("--use-audio-input", action="store_true")
    args = parser.parse_args()

    if args.joint:
        process_joint(args)
    else:
        process(args)


if __name__ == "__main__":
    main()