JustinLin610's picture
first commit
ee21b96
raw
history blame
No virus
2.69 kB
# Copyright (c) Facebook, Inc. and its affiliates.
#
# This source code is licensed under the MIT license found in the
# LICENSE file in the root directory of this source tree.
import argparse
import logging
import numpy as np
import re
from pathlib import Path
from collections import defaultdict
import pandas as pd
from torchaudio.datasets import VCTK
from tqdm import tqdm
from examples.speech_to_text.data_utils import save_df_to_tsv
log = logging.getLogger(__name__)
SPLITS = ["train", "dev", "test"]
def normalize_text(text):
return re.sub(r"[^a-zA-Z.?!,'\- ]", '', text)
def process(args):
out_root = Path(args.output_data_root).absolute()
out_root.mkdir(parents=True, exist_ok=True)
# Generate TSV manifest
print("Generating manifest...")
dataset = VCTK(out_root.as_posix(), download=False)
ids = list(dataset._walker)
np.random.seed(args.seed)
np.random.shuffle(ids)
n_train = len(ids) - args.n_dev - args.n_test
_split = ["train"] * n_train + ["dev"] * args.n_dev + ["test"] * args.n_test
id_to_split = dict(zip(ids, _split))
manifest_by_split = {split: defaultdict(list) for split in SPLITS}
progress = tqdm(enumerate(dataset), total=len(dataset))
for i, (waveform, _, text, speaker_id, _) in progress:
sample_id = dataset._walker[i]
_split = id_to_split[sample_id]
audio_dir = Path(dataset._path) / dataset._folder_audio / speaker_id
audio_path = audio_dir / f"{sample_id}.wav"
text = normalize_text(text)
manifest_by_split[_split]["id"].append(sample_id)
manifest_by_split[_split]["audio"].append(audio_path.as_posix())
manifest_by_split[_split]["n_frames"].append(len(waveform[0]))
manifest_by_split[_split]["tgt_text"].append(text)
manifest_by_split[_split]["speaker"].append(speaker_id)
manifest_by_split[_split]["src_text"].append(text)
manifest_root = Path(args.output_manifest_root).absolute()
manifest_root.mkdir(parents=True, exist_ok=True)
for _split in SPLITS:
save_df_to_tsv(
pd.DataFrame.from_dict(manifest_by_split[_split]),
manifest_root / f"{_split}.audio.tsv"
)
def main():
parser = argparse.ArgumentParser()
parser.add_argument("--output-data-root", "-d", required=True, type=str)
parser.add_argument("--output-manifest-root", "-m", required=True, type=str)
parser.add_argument("--n-dev", default=50, type=int)
parser.add_argument("--n-test", default=100, type=int)
parser.add_argument("--seed", "-s", default=1234, type=int)
args = parser.parse_args()
process(args)
if __name__ == "__main__":
main()