|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import logging |
|
import os |
|
from pathlib import Path |
|
import shutil |
|
from itertools import groupby |
|
from tempfile import NamedTemporaryFile |
|
from typing import Tuple |
|
|
|
import pandas as pd |
|
import soundfile as sf |
|
from examples.speech_to_text.data_utils import ( |
|
create_zip, |
|
extract_fbank_features, |
|
filter_manifest_df, |
|
gen_config_yaml, |
|
gen_vocab, |
|
get_zip_manifest, |
|
load_df_from_tsv, |
|
save_df_to_tsv, |
|
) |
|
import torch |
|
from torch.utils.data import Dataset |
|
from tqdm import tqdm |
|
|
|
from fairseq.data.audio.audio_utils import get_waveform, convert_waveform |
|
|
|
|
|
# Module-level logger, named after this module.
log = logging.getLogger(__name__)

# Column order of the per-split TSV manifests written by ``process``.
MANIFEST_COLUMNS = [
    "id",
    "audio",
    "n_frames",
    "tgt_text",
    "speaker",
    "tgt_lang",
]
|
|
|
|
|
class mTEDx(Dataset):
    """
    Create a Dataset for Multilingual TEDx.

    Each item is a tuple of the form: waveform, sample_rate, source utterance,
    target utterance, speaker_id, target language, utterance_id
    """

    SPLITS = ["train", "valid", "test"]
    LANGPAIRS = ["es-es", "fr-fr", "pt-pt", "it-it", "ru-ru", "el-el", "ar-ar",
                 "de-de", "es-en", "es-fr", "es-pt", "es-it", "fr-en", "fr-es",
                 "fr-pt", "pt-en", "pt-es", "it-en", "it-es", "ru-en", "el-en"]

    def __init__(self, root: str, lang: str, split: str) -> None:
        """Index every segment of ``<root>/<lang>/data/<split>``.

        Audio is not loaded here; only segment metadata and transcripts are
        read. Audio is decoded lazily in ``__getitem__``.

        Args:
            root: directory holding one sub-directory per language pair.
            lang: language pair ``"<src>-<tgt>"``; must be in ``LANGPAIRS``.
            split: one of ``SPLITS``.

        Raises:
            AssertionError: unknown ``lang``/``split``, missing ``wav``/``txt``
                directories, or a transcript whose line count differs from
                the number of segments in the YAML file.
            ImportError: PyYAML is not installed.
        """
        assert split in self.SPLITS, f"unknown split: {split}"
        assert lang in self.LANGPAIRS, f"unknown language pair: {lang}"
        _root = Path(root) / f"{lang}" / "data" / split
        wav_root, txt_root = _root / "wav", _root / "txt"
        assert _root.is_dir() and wav_root.is_dir() and txt_root.is_dir(), \
            f"expected wav/ and txt/ directories under {_root}"

        # Load audio segments
        try:
            import yaml
        except ImportError as e:
            # Fail here instead of crashing later with a NameError on `yaml`.
            raise ImportError(
                "Please install PyYAML to load the Multilingual TEDx YAML files"
            ) from e
        with open(txt_root / f"{split}.yaml", encoding="utf-8") as f:
            # BaseLoader keeps every scalar as a string, hence the float()
            # conversions of "offset"/"duration" below.
            segments = yaml.load(f, Loader=yaml.BaseLoader)

        # Attach the transcript of each language to its segment; line i of
        # "<split>.<lang>" belongs to segment i. dict.fromkeys de-duplicates
        # same-language pairs (e.g. "es-es") so the file is read once.
        src, tgt = lang.split("-")
        for _lang in dict.fromkeys((src, tgt)):
            with open(txt_root / f"{split}.{_lang}", encoding="utf-8") as f:
                utterances = [r.strip() for r in f]
            assert len(segments) == len(utterances), \
                f"{split}.{_lang}: {len(utterances)} lines for " \
                f"{len(segments)} segments"
            for segment, utterance in zip(segments, utterances):
                segment[_lang] = utterance

        # Gather info. Each entry is (path, start frame, frame count,
        # sample rate, src text, tgt text, speaker id, tgt lang, utt id).
        self.data = []
        groups = self._group_segments_by_wav(segments)
        for wav_filename, seg_group in groups.items():
            # Segment metadata names ".wav" files; audio is read from ".flac".
            wav_path = wav_root / wav_filename.replace(".wav", ".flac")
            sample_rate = sf.info(wav_path.as_posix()).samplerate
            seg_group = sorted(seg_group, key=lambda x: float(x["offset"]))
            for i, segment in enumerate(seg_group):
                # Seconds -> sample frames at the file's native rate.
                offset = int(float(segment["offset"]) * sample_rate)
                n_frames = int(float(segment["duration"]) * sample_rate)
                _id = f"{wav_path.stem}_{i}"
                self.data.append(
                    (
                        wav_path.as_posix(),
                        offset,
                        n_frames,
                        sample_rate,
                        segment[src],
                        segment[tgt],
                        segment["speaker_id"],
                        tgt,
                        _id,
                    )
                )

    @staticmethod
    def _group_segments_by_wav(segments: list) -> dict:
        """Group segments by their ``"wav"`` key, keeping first-seen order.

        Unlike ``itertools.groupby`` this does not require all segments of a
        recording to be contiguous. Each recording therefore yields exactly
        one group, and the per-recording indices used in utterance ids stay
        unique.
        """
        groups = {}
        for segment in segments:
            groups.setdefault(segment["wav"], []).append(segment)
        return groups

    def __getitem__(
        self, n: int
    ) -> Tuple[torch.Tensor, int, str, str, str, str, str]:
        """Decode segment ``n`` and return it with its metadata.

        Returns:
            ``(waveform, sample_rate, src_utt, tgt_utt, speaker_id, tgt_lang,
            utt_id)``. ``sample_rate`` is the file's native rate. ``waveform``
            is the tensor built from ``get_waveform`` output; presumably
            (channels, frames), so confirm against ``get_waveform``.
        """
        (wav_path, offset, n_frames, sr, src_utt, tgt_utt, spk_id, tgt_lang,
         utt_id) = self.data[n]
        waveform, _ = get_waveform(wav_path, frames=n_frames, start=offset)
        waveform = torch.from_numpy(waveform)
        return waveform, sr, src_utt, tgt_utt, spk_id, tgt_lang, utt_id

    def __len__(self) -> int:
        """Return the number of indexed segments."""
        return len(self.data)
|
|
|
|
|
def process(args):
    """Prepare every available Multilingual TEDx language pair separately.

    For each pair in ``mTEDx.LANGPAIRS`` whose directory exists under
    ``args.data_root``, this function:

    1. writes one file per segment into a scratch directory. With
       ``args.use_audio_input`` set, these are resampled 16 kHz mono FLAC
       files; otherwise they are log mel filter bank features (``.npy``)
       produced by ``extract_fbank_features``. The directory is then packed
       into a ZIP next to it;
    2. writes a ``{split}_{task}.tsv`` manifest for every split. The
       ``tgt_text`` column holds source transcripts for ``asr`` and target
       translations otherwise;
    3. builds a ``spm_*`` vocabulary from the train-split target text via
       ``gen_vocab`` and writes ``config_{task}.yaml``;
    4. removes the scratch directory (the ZIP is kept).

    Pairs whose directory is missing are reported and skipped.

    Args:
        args: parsed CLI namespace providing ``data_root``,
            ``use_audio_input``, ``task``, ``vocab_type`` and ``vocab_size``.
    """
    root = Path(args.data_root).absolute()
    for lang in mTEDx.LANGPAIRS:
        cur_root = root / f"{lang}"
        if not cur_root.is_dir():
            print(f"{cur_root.as_posix()} does not exist. Skipped.")
            continue

        # Extract features (or resampled audio) into a scratch directory
        audio_root = cur_root / ("flac" if args.use_audio_input else "fbank80")
        audio_root.mkdir(exist_ok=True)
        for split in mTEDx.SPLITS:
            print(f"Fetching split {split}...")
            dataset = mTEDx(root.as_posix(), lang, split)
            if args.use_audio_input:
                print("Converting audios...")
                tgt_sample_rate = 16_000
                # Dataset items are 7-tuples; only waveform, rate and id
                # are needed here.
                for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
                    _wavform, _ = convert_waveform(
                        waveform, sample_rate, to_mono=True,
                        to_sample_rate=tgt_sample_rate
                    )
                    # convert_waveform works on (channels, frames) arrays,
                    # whereas soundfile.write expects (frames, channels).
                    sf.write(
                        (audio_root / f"{utt_id}.flac").as_posix(),
                        _wavform.T.numpy(), tgt_sample_rate
                    )
            else:
                print("Extracting log mel filter bank features...")
                for waveform, sample_rate, _, _, _, _, utt_id in tqdm(dataset):
                    extract_fbank_features(
                        waveform, sample_rate, audio_root / f"{utt_id}.npy"
                    )

        # Pack audios/features into ZIP
        zip_path = cur_root / f"{audio_root.name}.zip"
        print("ZIPing audios/features...")
        create_zip(audio_root, zip_path)
        print("Fetching ZIP manifest...")
        audio_paths, audio_lengths = get_zip_manifest(zip_path)

        # Generate TSV manifest
        print("Generating manifest...")
        train_text = []
        for split in mTEDx.SPLITS:
            is_train_split = split.startswith("train")
            manifest = {c: [] for c in MANIFEST_COLUMNS}
            ds = mTEDx(root.as_posix(), lang, split)
            # Only metadata is needed, so read the index directly instead of
            # iterating the dataset, which would decode every audio again.
            for *_, src_utt, tgt_utt, spk_id, tgt_lang, utt_id in tqdm(ds.data):
                manifest["id"].append(utt_id)
                manifest["audio"].append(audio_paths[utt_id])
                manifest["n_frames"].append(audio_lengths[utt_id])
                manifest["tgt_text"].append(
                    src_utt if args.task == "asr" else tgt_utt
                )
                manifest["speaker"].append(spk_id)
                manifest["tgt_lang"].append(tgt_lang)
            # NOTE: collected before filter_manifest_df, so rows dropped by
            # the filter still contribute to the vocabulary text.
            if is_train_split:
                train_text.extend(manifest["tgt_text"])
            df = pd.DataFrame.from_dict(manifest)
            df = filter_manifest_df(df, is_train_split=is_train_split)
            save_df_to_tsv(df, cur_root / f"{split}_{args.task}.tsv")

        # Generate vocab
        v_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
        spm_filename_prefix = f"spm_{args.vocab_type}{v_size_str}_{args.task}"
        with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
            f.writelines(t + "\n" for t in train_text)
            # gen_vocab reads the file by name: push buffered text to disk.
            f.flush()
            gen_vocab(
                Path(f.name),
                cur_root / spm_filename_prefix,
                args.vocab_type,
                args.vocab_size,
            )

        # Generate config YAML
        if args.use_audio_input:
            gen_config_yaml(
                cur_root,
                spm_filename=spm_filename_prefix + ".model",
                yaml_filename=f"config_{args.task}.yaml",
                specaugment_policy=None,
                extra={"use_audio_input": True}
            )
        else:
            gen_config_yaml(
                cur_root,
                spm_filename=spm_filename_prefix + ".model",
                yaml_filename=f"config_{args.task}.yaml",
                specaugment_policy="lb",
            )

        # Clean up the scratch directory; the ZIP keeps the data
        shutil.rmtree(audio_root)
|
|
|
|
|
def process_joint(args):
    """Build one vocabulary and config shared by all language pairs.

    Requires a previous per-pair run (``process``), so that
    ``<data_root>/<lang>/{split}_{task}.tsv`` exists for every pair in
    ``mTEDx.LANGPAIRS``. Writes the ``spm_*`` vocabulary and
    ``config_{task}.yaml`` at ``data_root``. Each per-pair manifest is
    exposed there as a ``{split}_{lang}_{task}.tsv`` symlink.

    Args:
        args: parsed CLI namespace providing ``data_root``, ``task``,
            ``vocab_type``, ``vocab_size`` and ``joint``.

    Raises:
        AssertionError: if any language-pair directory is missing.
    """
    cur_root = Path(args.data_root)
    assert all((cur_root / f"{lang}").is_dir() for lang in mTEDx.LANGPAIRS), \
        "do not have downloaded data available for all languages"

    # Generate vocab from the train-split text of every pair
    vocab_size_str = "" if args.vocab_type == "char" else str(args.vocab_size)
    spm_filename_prefix = f"spm_{args.vocab_type}{vocab_size_str}_{args.task}"
    with NamedTemporaryFile(mode="w", encoding="utf-8") as f:
        for lang in mTEDx.LANGPAIRS:
            tsv_path = cur_root / f"{lang}" / f"train_{args.task}.tsv"
            df = load_df_from_tsv(tsv_path)
            for t in df["tgt_text"]:
                f.write(t + "\n")
        # gen_vocab reads the file by name: push buffered text to disk.
        f.flush()
        special_symbols = None
        if args.joint:
            # Add tgt_lang tags to dict. sorted() fixes their order (and so
            # their vocabulary ids); iterating a set of str varies between
            # runs because of hash randomization.
            special_symbols = sorted(
                {f'<lang:{lang.split("-")[1]}>' for lang in mTEDx.LANGPAIRS}
            )
        gen_vocab(
            Path(f.name),
            cur_root / spm_filename_prefix,
            args.vocab_type,
            args.vocab_size,
            special_symbols=special_symbols
        )

    # Generate config YAML
    gen_config_yaml(
        cur_root,
        spm_filename=spm_filename_prefix + ".model",
        yaml_filename=f"config_{args.task}.yaml",
        specaugment_policy="ld",
        prepend_tgt_lang_tag=(args.joint),
    )

    # Expose every per-pair manifest at the data root. A relative symlink
    # target is resolved against the link's own directory (cur_root), so it
    # must be relative to cur_root, not to the current working directory.
    for lang in mTEDx.LANGPAIRS:
        for split in mTEDx.SPLITS:
            link_target = Path(f"{lang}") / f"{split}_{args.task}.tsv"
            link_path = cur_root / f"{split}_{lang}_{args.task}.tsv"
            if link_path.is_symlink() and not link_path.exists():
                # Dangling link: replace it rather than keep a broken one.
                link_path.unlink()
            if not link_path.is_symlink():
                os.symlink(link_target, link_path)
|
|
|
|
|
def main():
    """Parse command-line arguments and run per-pair or joint preparation."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--data-root", "-d", required=True, type=str)
    # Not required: the documented "unigram" default applies when omitted.
    parser.add_argument(
        "--vocab-type",
        default="unigram",
        type=str,
        choices=["bpe", "unigram", "char"],
    )
    parser.add_argument("--vocab-size", default=8000, type=int)
    # Required: output file names and the ASR-vs-ST target text depend on it.
    parser.add_argument(
        "--task", type=str, required=True, choices=["asr", "st"]
    )
    parser.add_argument(
        "--joint",
        action="store_true",
        help="build a shared vocabulary/config over all language pairs "
        "(needs the per-pair manifests from a previous non-joint run)",
    )
    parser.add_argument(
        "--use-audio-input",
        action="store_true",
        help="store 16 kHz mono FLAC audio instead of filter bank features",
    )
    args = parser.parse_args()

    if args.joint:
        process_joint(args)
    else:
        process(args)
|
|
|
|
|
if __name__ == "__main__":
    # Run the command-line entry point only when executed as a script.
    main()
|
|