# File provenance (Hugging Face Hub UI residue, kept as comments so the file is valid Python):
# Uploaded by user "victan" via huggingface_hub as
# seamless_communication/cli/m4t/finetune/dataset.py, commit 519ab1b, 5.72 kB.
# Copyright (c) Meta Platforms, Inc. and affiliates
# All rights reserved.
#
# This source code is licensed under the license found in the
# MIT_LICENSE file in the root directory of this source tree.
import argparse
import dataclasses
import json
import logging
import os
from pathlib import Path
import torch
from seamless_communication.datasets.huggingface import (
Speech2SpeechFleursDatasetBuilder,
SpeechTokenizer,
)
from seamless_communication.models.unit_extractor import UnitExtractor
logging.basicConfig(
    level=logging.INFO,
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
)
logger = logging.getLogger("dataset")

# Maps M4T (ISO 639-3) language codes to FLEURS dataset language codes.
# Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
# Full list of M4T langcodes is available
# in paper "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
# NOTE: the original literal listed "ita" twice ("eng" aside, entries are sorted
# by FLEURS code); the duplicate was silently overwritten by the later entry, so
# it has been removed — behavior of lookups is unchanged.
UNITY_TO_FLEURS_LANG_MAPPING = {
    "eng": "en_us",
    "afr": "af_za",
    "asm": "as_in",
    "bel": "be_by",
    "bul": "bg_bg",
    "ben": "bn_in",
    "cat": "ca_es",
    "ces": "cs_cz",
    "dan": "da_dk",
    "deu": "de_de",
    "ell": "el_gr",
    "fin": "fi_fi",
    "fra": "fr_fr",
    "glg": "gl_es",
    "heb": "he_il",
    "hin": "hi_in",
    "hrv": "hr_hr",
    "hun": "hu_hu",
    "ind": "id_id",
    "ibo": "ig_ng",
    "isl": "is_is",
    "ita": "it_it",
    "jpn": "ja_jp",
    "jav": "jv_id",
    "kaz": "kk_kz",
    "kan": "kn_in",
    "kir": "ky_kg",
    "kor": "ko_kr",
    "lit": "lt_lt",
    "mkd": "mk_mk",
    "mlt": "mt_mt",
    "mya": "my_mm",
    "nld": "nl_nl",
    "pan": "pa_in",
    "pol": "pl_pl",
    "ron": "ro_ro",
    "rus": "ru_ru",
    "snd": "sd_in",
    "slk": "sk_sk",
    "srp": "sr_rs",
    "swh": "sw_ke",
    "tam": "ta_in",
    "tel": "te_in",
    "tha": "th_th",
    "tur": "tr_tr",
    "ukr": "uk_ua",
    "urd": "ur_pk",
    "uzn": "uz_uz",
    "vie": "vi_vn",
    "yor": "yo_ng",
    "zul": "zu_za",
}
def _check_lang_code_mapping(lang: str) -> None:
    """Raise ValueError if `lang` has no M4T->FLEURS code mapping."""
    if lang in UNITY_TO_FLEURS_LANG_MAPPING:
        return
    raise ValueError(
        f"No language code mapping for {lang}(M4T)->??(FLEURs). "
        "Please expand `UNITY_TO_FLEURS_LANG_MAPPING`"
    )
class UnitSpeechTokenizer(SpeechTokenizer):
    """Speech tokenizer that discretizes audio into units with `UnitExtractor`."""

    # Model card / kmeans codebook used for unit extraction.
    MODEL_NAME = "xlsr2_1b_v2"
    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    OUTPUT_LAYER_IDX = 34

    def __init__(self, device: torch.device):
        """Instantiate the unit extractor on `device`."""
        super().__init__()
        self.device = device
        self.unit_extractor = UnitExtractor(
            model_name_or_card=self.MODEL_NAME,
            kmeans_uri=self.KMEANS_MODEL_URI,
            device=self.device,
        )

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Return the discrete unit sequence extracted from `wav`."""
        audio = wav.to(self.device)
        return self.unit_extractor.predict(
            audio,
            out_layer_idx=self.OUTPUT_LAYER_IDX,
            sample_rate=sample_rate,
        )
def download_fleurs_dataset(
    source_lang: str,
    target_lang: str,
    split: str,
    save_directory: str,
) -> str:
    """Download a FLEURS split for a language pair and write a JSON-lines manifest.

    Target-side audio is tokenized into discrete units (source audio units are
    skipped); each sample is serialized as one JSON line consumable by
    `finetune.py`.

    Args:
        source_lang: M4T langcode of the source language.
        target_lang: M4T langcode of the target language.
        split: dataset split to download (`train`, `validation`, `test`).
        save_directory: directory for the HF datasets cache and the manifest.

    Returns:
        Path to the written manifest file.

    Raises:
        ValueError: if either langcode has no FLEURS mapping.
    """
    _check_lang_code_mapping(source_lang)
    _check_lang_code_mapping(target_lang)
    device = (
        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
    )
    tokenizer = UnitSpeechTokenizer(device=device)
    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
        dataset_cache_dir=save_directory,
        speech_tokenizer=tokenizer,
        skip_source_audio=True,  # don't extract units from source audio
        skip_target_audio=False,
        split=split,
    )
    # Make sure the output directory exists before opening the manifest.
    os.makedirs(save_directory, exist_ok=True)
    manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
    num_samples = 0  # keeps the log line valid even if the split is empty
    with open(manifest_path, "w") as fp_out:
        # Iterate the builder directly instead of calling `.__iter__()` explicitly.
        for num_samples, sample in enumerate(dataset_iterator, start=1):
            # correction as FleursDatasetBuilder return fleurs lang codes
            sample.source.lang = source_lang
            sample.target.lang = target_lang
            sample.target.waveform = None  # already extracted units
            fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
    logger.info(f"Saved {num_samples} samples for split={split} to {manifest_path}")
    return manifest_path
def init_parser() -> argparse.ArgumentParser:
    """Build the CLI argument parser for the dataset-download helper."""
    description = (
        "Helper script to download training/evaluation dataset (FLEURS),"
        "extract units from target audio and save the dataset as a manifest "
        "consumable by `finetune.py`."
    )
    parser = argparse.ArgumentParser(description=description)
    # (flag, value type, help text) for each required option.
    required_options = [
        ("--source_lang", str, "M4T langcode of the dataset SOURCE language"),
        ("--target_lang", str, "M4T langcode of the dataset TARGET language"),
        (
            "--split",
            str,
            "Dataset split/shard to download (`train`, `validation`, `test`)",
        ),
        (
            "--save_dir",
            Path,
            "Directory where the datastets will be stored with HuggingFace datasets cache files",
        ),
    ]
    for flag, value_type, help_text in required_options:
        parser.add_argument(flag, type=value_type, required=True, help=help_text)
    return parser
def main() -> None:
    """CLI entry point: parse arguments and download the requested dataset."""
    args = init_parser().parse_args()
    path = download_fleurs_dataset(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        split=args.split,
        save_directory=args.save_dir,
    )
    logger.info(f"Manifest saved to: {path}")


if __name__ == "__main__":
    main()