|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import argparse |
|
import dataclasses |
|
import json |
|
import logging |
|
import os |
|
from pathlib import Path |
|
|
|
import torch |
|
|
|
from seamless_communication.datasets.huggingface import ( |
|
Speech2SpeechFleursDatasetBuilder, |
|
SpeechTokenizer, |
|
) |
|
from seamless_communication.models.unit_extractor import UnitExtractor |
|
|
|
# Configure root logging once for this script: INFO level, timestamped lines.
logging.basicConfig(
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
    level=logging.INFO,
)

# Script-wide logger, registered under the fixed name "dataset".
logger = logging.getLogger(name="dataset")
|
|
|
|
|
|
|
|
|
|
|
# Maps M4T (UnitY) three-letter language codes to the FLEURS language
# identifiers expected by `Speech2SpeechFleursDatasetBuilder`.
# Each key appears exactly once; extend this table to support more languages.
UNITY_TO_FLEURS_LANG_MAPPING = {
    "eng": "en_us",
    "ita": "it_it",
    "afr": "af_za",
    "asm": "as_in",
    "bel": "be_by",
    "bul": "bg_bg",
    "ben": "bn_in",
    "cat": "ca_es",
    "ces": "cs_cz",
    "dan": "da_dk",
    "deu": "de_de",
    "ell": "el_gr",
    "fin": "fi_fi",
    "fra": "fr_fr",
    "glg": "gl_es",
    "heb": "he_il",
    "hin": "hi_in",
    "hrv": "hr_hr",
    "hun": "hu_hu",
    "ind": "id_id",
    "ibo": "ig_ng",
    "isl": "is_is",
    "jpn": "ja_jp",
    "jav": "jv_id",
    "kaz": "kk_kz",
    "kan": "kn_in",
    "kir": "ky_kg",
    "kor": "ko_kr",
    "lit": "lt_lt",
    "mkd": "mk_mk",
    "mlt": "mt_mt",
    "mya": "my_mm",
    "nld": "nl_nl",
    "pan": "pa_in",
    "pol": "pl_pl",
    "ron": "ro_ro",
    "rus": "ru_ru",
    "snd": "sd_in",
    "slk": "sk_sk",
    "srp": "sr_rs",
    "swh": "sw_ke",
    "tam": "ta_in",
    "tel": "te_in",
    "tha": "th_th",
    "tur": "tr_tr",
    "ukr": "uk_ua",
    "urd": "ur_pk",
    "uzn": "uz_uz",
    "vie": "vi_vn",
    "yor": "yo_ng",
    "zul": "zu_za",
}
|
|
|
|
|
def _check_lang_code_mapping(lang: str) -> None:
    """Ensure an M4T language code has a FLEURS counterpart.

    Args:
        lang: M4T language code to look up in ``UNITY_TO_FLEURS_LANG_MAPPING``.

    Raises:
        ValueError: If ``lang`` is not a key of the mapping.
    """
    if lang in UNITY_TO_FLEURS_LANG_MAPPING:
        return
    message = (
        f"No language code mapping for {lang}(M4T)->??(FLEURs). "
        "Please expand `UNITY_TO_FLEURS_LANG_MAPPING`"
    )
    raise ValueError(message)
|
|
|
|
|
class UnitSpeechTokenizer(SpeechTokenizer):
    """``SpeechTokenizer`` that converts waveforms into discrete units.

    All work is delegated to a ``UnitExtractor``. It is built from the model
    name and k-means file configured on this class. Features are taken from
    layer ``OUTPUT_LAYER_IDX``.
    """

    # Model name/card handed to UnitExtractor.
    MODEL_NAME = "xlsr2_1b_v2"
    # Location of the k-means centroids file (presumably 10k clusters, per the
    # file name — TODO confirm).
    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    # Forwarded to UnitExtractor.predict as `out_layer_idx`.
    OUTPUT_LAYER_IDX = 34

    def __init__(self, device: torch.device):
        """Create the underlying unit extractor on ``device``."""
        super().__init__()
        self.device = device
        extractor_kwargs = dict(
            model_name_or_card=self.MODEL_NAME,
            kmeans_uri=self.KMEANS_MODEL_URI,
            device=device,
        )
        self.unit_extractor = UnitExtractor(**extractor_kwargs)

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Move ``wav`` to this tokenizer's device and predict its units."""
        wav_on_device = wav.to(self.device)
        layer = self.OUTPUT_LAYER_IDX
        return self.unit_extractor.predict(
            wav_on_device, out_layer_idx=layer, sample_rate=sample_rate
        )
|
|
|
|
|
def download_fleurs_dataset(
    source_lang: str,
    target_lang: str,
    split: str,
    save_directory: str,
) -> str:
    """Download one FLEURS split, extract target-speech units and write a manifest.

    Both language codes are validated against ``UNITY_TO_FLEURS_LANG_MAPPING``.
    They are then translated to their FLEURS counterparts for the dataset
    builder. Each sample is written as one JSON object per line (JSON Lines,
    despite the ``.json`` suffix) to ``<save_directory>/<split>_manifest.json``.

    Args:
        source_lang: M4T language code of the source side (e.g. ``"eng"``).
        target_lang: M4T language code of the target side.
        split: Dataset split to download, forwarded to the builder as-is.
        save_directory: Directory used both as the dataset cache and as the
            location of the manifest file. It is created if missing.
            ``main`` passes a ``pathlib.Path`` here; ``os`` path functions
            accept either form.

    Returns:
        The path of the written manifest file.

    Raises:
        ValueError: If either language code has no FLEURS mapping.
    """
    _check_lang_code_mapping(source_lang)
    _check_lang_code_mapping(target_lang)
    device = (
        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
    )
    tokenizer = UnitSpeechTokenizer(device=device)
    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
        dataset_cache_dir=save_directory,
        speech_tokenizer=tokenizer,
        # Only target-side audio is loaded. Presumably that is the audio the
        # tokenizer turns into units (TODO confirm against the builder).
        skip_source_audio=True,
        skip_target_audio=False,
        split=split,
    )
    # The manifest is opened before the dataset is iterated. Make sure the
    # directory exists rather than relying on the builder to create it.
    os.makedirs(save_directory, exist_ok=True)
    manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
    # Bind the counter up front. An empty split then logs "Saved 0 samples"
    # instead of raising UnboundLocalError at the log call below.
    num_samples = 0
    with open(manifest_path, "w", encoding="utf-8") as fp_out:
        for num_samples, sample in enumerate(dataset_iterator, start=1):
            # Store the M4T language codes, not the FLEURS ones, in the manifest.
            sample.source.lang = source_lang
            sample.target.lang = target_lang
            # Drop the raw target waveform before serialising. It is presumably
            # a tensor, which json.dumps cannot encode and which would bloat
            # the manifest.
            sample.target.waveform = None
            fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
    logger.info(
        "Saved %d samples for split=%s to %s", num_samples, split, manifest_path
    )
    return manifest_path
|
|
|
|
|
def init_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        An ``argparse.ArgumentParser`` with four required options:
        ``--source_lang``, ``--target_lang``, ``--split`` and ``--save_dir``.
        ``--save_dir`` is parsed into a ``pathlib.Path``.
    """
    parser = argparse.ArgumentParser(
        description=(
            # Trailing spaces matter: the literals are concatenated.
            "Helper script to download training/evaluation dataset (FLEURS), "
            "extract units from target audio and save the dataset as a manifest "
            "consumable by `finetune.py`."
        )
    )
    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset SOURCE language",
    )
    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset TARGET language",
    )
    parser.add_argument(
        "--split",
        type=str,
        required=True,
        help="Dataset split/shard to download (`train`, `validation`, `test`)",
    )
    parser.add_argument(
        "--save_dir",
        type=Path,
        required=True,
        help="Directory where the datasets will be stored with HuggingFace datasets cache files",
    )
    return parser
|
|
|
|
|
def main() -> None:
    """Command-line entry point: parse the CLI options and build the manifest."""
    parsed = vars(init_parser().parse_args())
    path_to_manifest = download_fleurs_dataset(
        source_lang=parsed["source_lang"],
        target_lang=parsed["target_lang"],
        split=parsed["split"],
        save_directory=parsed["save_dir"],
    )
    logger.info(f"Manifest saved to: {path_to_manifest}")


# Run only when executed as a script, not when imported.
if __name__ == "__main__":
    main()
|
|