Upload seamless_communication/cli/m4t/finetune/dataset.py with huggingface_hub
Browse files
seamless_communication/cli/m4t/finetune/dataset.py
ADDED
@@ -0,0 +1,200 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates
|
2 |
+
# All rights reserved.
|
3 |
+
#
|
4 |
+
# This source code is licensed under the license found in the
|
5 |
+
# MIT_LICENSE file in the root directory of this source tree.
|
6 |
+
|
7 |
+
|
8 |
+
import argparse
|
9 |
+
import dataclasses
|
10 |
+
import json
|
11 |
+
import logging
|
12 |
+
import os
|
13 |
+
from pathlib import Path
|
14 |
+
|
15 |
+
import torch
|
16 |
+
|
17 |
+
from seamless_communication.datasets.huggingface import (
|
18 |
+
Speech2SpeechFleursDatasetBuilder,
|
19 |
+
SpeechTokenizer,
|
20 |
+
)
|
21 |
+
from seamless_communication.models.unit_extractor import UnitExtractor
|
22 |
+
|
23 |
+
# Configure the root logger at import time so that records emitted by this
# script carry a timestamp, the level and the emitting logger's name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
    level=logging.INFO,
)

# Logger used by the functions in this file.
logger = logging.getLogger("dataset")
|
29 |
+
|
30 |
+
|
31 |
+
# Mapping from M4T language codes (keys) to FLEURS language codes (values).
# Each key must appear exactly once: a repeated key in a dict literal is
# silently overridden by its last occurrence.
# Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
# Full list of M4T langcodes is available
# in paper "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
UNITY_TO_FLEURS_LANG_MAPPING = {
    "eng": "en_us",
    "ita": "it_it",
    "afr": "af_za",
    "asm": "as_in",
    "bel": "be_by",
    "bul": "bg_bg",
    "ben": "bn_in",
    "cat": "ca_es",
    "ces": "cs_cz",
    "dan": "da_dk",
    "deu": "de_de",
    "ell": "el_gr",
    "fin": "fi_fi",
    "fra": "fr_fr",
    "glg": "gl_es",
    "heb": "he_il",
    "hin": "hi_in",
    "hrv": "hr_hr",
    "hun": "hu_hu",
    "ind": "id_id",
    "ibo": "ig_ng",
    "isl": "is_is",
    "jpn": "ja_jp",
    "jav": "jv_id",
    "kaz": "kk_kz",
    "kan": "kn_in",
    "kir": "ky_kg",
    "kor": "ko_kr",
    "lit": "lt_lt",
    "mkd": "mk_mk",
    "mlt": "mt_mt",
    "mya": "my_mm",
    "nld": "nl_nl",
    "pan": "pa_in",
    "pol": "pl_pl",
    "ron": "ro_ro",
    "rus": "ru_ru",
    "snd": "sd_in",
    "slk": "sk_sk",
    "srp": "sr_rs",
    "swh": "sw_ke",
    "tam": "ta_in",
    "tel": "te_in",
    "tha": "th_th",
    "tur": "tr_tr",
    "ukr": "uk_ua",
    "urd": "ur_pk",
    "uzn": "uz_uz",
    "vie": "vi_vn",
    "yor": "yo_ng",
    "zul": "zu_za",
}
|
88 |
+
|
89 |
+
|
90 |
+
def _check_lang_code_mapping(lang: str) -> None:
    """Ensure `lang` is a key of `UNITY_TO_FLEURS_LANG_MAPPING`.

    Args:
        lang: Language code to look up in the mapping.

    Raises:
        ValueError: If `lang` is not a key of the mapping.
    """
    if lang in UNITY_TO_FLEURS_LANG_MAPPING:
        return
    message = (
        f"No language code mapping for {lang}(M4T)->??(FLEURs). "
        "Please expand `UNITY_TO_FLEURS_LANG_MAPPING`"
    )
    raise ValueError(message)
|
96 |
+
|
97 |
+
|
98 |
+
class UnitSpeechTokenizer(SpeechTokenizer):
    """`SpeechTokenizer` that delegates encoding to a `UnitExtractor`."""

    # Model name/card handed to `UnitExtractor`.
    MODEL_NAME = "xlsr2_1b_v2"
    # Location of the k-means model handed to `UnitExtractor`.
    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    # Passed as `out_layer_idx` on every `predict` call.
    OUTPUT_LAYER_IDX = 34

    def __init__(self, device: torch.device):
        """Build the underlying `UnitExtractor` on `device`.

        Args:
            device: Device given to the extractor; waveforms are moved to it
                before encoding.
        """
        super().__init__()
        self.device = device
        extractor = UnitExtractor(
            model_name_or_card=self.MODEL_NAME,
            kmeans_uri=self.KMEANS_MODEL_URI,
            device=self.device,
        )
        self.unit_extractor = extractor

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Return the result of `UnitExtractor.predict` for `wav`.

        Args:
            wav: Input waveform tensor; moved to `self.device` first.
            sample_rate: Sample rate forwarded to the extractor.
        """
        wav_on_device = wav.to(self.device)
        return self.unit_extractor.predict(
            wav_on_device,
            out_layer_idx=self.OUTPUT_LAYER_IDX,
            sample_rate=sample_rate,
        )
|
118 |
+
|
119 |
+
|
120 |
+
def download_fleurs_dataset(
    source_lang: str,
    target_lang: str,
    split: str,
    save_directory: str,
) -> str:
    """Iterate a FLEURS split and write it out as a JSON-lines manifest.

    Each sample produced by `Speech2SpeechFleursDatasetBuilder` is written as
    one JSON object per line to `<save_directory>/<split>_manifest.json`.

    Args:
        source_lang: M4T language code of the source side.
        target_lang: M4T language code of the target side.
        split: Name of the dataset split to process.
        save_directory: Directory used both as the dataset cache dir and as
            the location of the manifest; created if it does not exist.

    Returns:
        Path of the written manifest file.

    Raises:
        ValueError: If either language code is missing from
            `UNITY_TO_FLEURS_LANG_MAPPING`.
    """
    _check_lang_code_mapping(source_lang)
    _check_lang_code_mapping(target_lang)
    device = (
        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
    )
    tokenizer = UnitSpeechTokenizer(device=device)
    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
        dataset_cache_dir=save_directory,
        speech_tokenizer=tokenizer,
        skip_source_audio=True,  # don't extract units from source audio
        skip_target_audio=False,
        split=split,
    )
    # The manifest is opened before the dataset is iterated, so make sure the
    # output directory exists instead of failing with FileNotFoundError.
    os.makedirs(save_directory, exist_ok=True)
    manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
    # Pre-initialise the counter: the loop variable stays unbound when the
    # iterator yields nothing, which would break the summary log below.
    num_samples = 0
    with open(manifest_path, "w", encoding="utf-8") as fp_out:
        for num_samples, sample in enumerate(dataset_iterator, start=1):
            # correction as FleursDatasetBuilder return fleurs lang codes
            sample.source.lang = source_lang
            sample.target.lang = target_lang
            sample.target.waveform = None  # already extracted units
            fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
    logger.info(f"Saved {num_samples} samples for split={split} to {manifest_path}")
    return manifest_path
|
151 |
+
|
152 |
+
|
153 |
+
def init_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    Returns:
        A parser with the required options `--source_lang`, `--target_lang`,
        `--split` (all strings) and `--save_dir` (converted to `Path`).
    """
    parser = argparse.ArgumentParser(
        description=(
            # Adjacent literals are concatenated verbatim, so each fragment
            # needs its own trailing space.
            "Helper script to download training/evaluation dataset (FLEURS), "
            "extract units from target audio and save the dataset as a manifest "
            "consumable by `finetune.py`."
        )
    )
    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset SOURCE language",
    )
    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset TARGET language",
    )
    parser.add_argument(
        "--split",
        type=str,
        required=True,
        help="Dataset split/shard to download (`train`, `validation`, `test`)",
    )
    parser.add_argument(
        "--save_dir",
        type=Path,
        required=True,
        help="Directory where the datasets will be stored with HuggingFace datasets cache files",
    )
    return parser
|
186 |
+
|
187 |
+
|
188 |
+
def main() -> None:
    """Parse the CLI arguments, build the manifest and log where it was saved."""
    cli_args = init_parser().parse_args()
    saved_manifest = download_fleurs_dataset(
        source_lang=cli_args.source_lang,
        target_lang=cli_args.target_lang,
        split=cli_args.split,
        save_directory=cli_args.save_dir,
    )
    logger.info(f"Manifest saved to: {saved_manifest}")


# Run the CLI only when the file is executed as a script, not when imported.
if __name__ == "__main__":
    main()
|