victan committed on
Commit
519ab1b
1 Parent(s): b84aa12

Upload seamless_communication/cli/m4t/finetune/dataset.py with huggingface_hub

Browse files
seamless_communication/cli/m4t/finetune/dataset.py ADDED
@@ -0,0 +1,200 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates
2
+ # All rights reserved.
3
+ #
4
+ # This source code is licensed under the license found in the
5
+ # MIT_LICENSE file in the root directory of this source tree.
6
+
7
+
8
+ import argparse
9
+ import dataclasses
10
+ import json
11
+ import logging
12
+ import os
13
+ from pathlib import Path
14
+
15
+ import torch
16
+
17
+ from seamless_communication.datasets.huggingface import (
18
+ Speech2SpeechFleursDatasetBuilder,
19
+ SpeechTokenizer,
20
+ )
21
+ from seamless_communication.models.unit_extractor import UnitExtractor
22
+
23
# Configure root logging: INFO threshold, records prefixed with timestamp,
# level and logger name.
logging.basicConfig(
    format="%(asctime)s %(levelname)s -- %(name)s: %(message)s",
    level=logging.INFO,
)

# Module-level logger; its fixed name "dataset" fills the %(name)s field.
logger = logging.getLogger("dataset")
29
+
30
+
31
# Maps M4T language codes (keys) to FLEURS language codes (values).
# Full list of FLEURS langcodes is available at https://huggingface.co/datasets/google/fleurs
# Full list of M4T langcodes is available
# in paper "SeamlessM4T—Massively Multilingual & Multimodal Machine Translation" (Table 5)
# NOTE: every key must be listed exactly once -- a repeated key in a dict
# literal is silently overridden by its last occurrence.
UNITY_TO_FLEURS_LANG_MAPPING = {
    "eng": "en_us",
    "ita": "it_it",
    "afr": "af_za",
    "asm": "as_in",
    "bel": "be_by",
    "bul": "bg_bg",
    "ben": "bn_in",
    "cat": "ca_es",
    "ces": "cs_cz",
    "dan": "da_dk",
    "deu": "de_de",
    "ell": "el_gr",
    "fin": "fi_fi",
    "fra": "fr_fr",
    "glg": "gl_es",
    "heb": "he_il",
    "hin": "hi_in",
    "hrv": "hr_hr",
    "hun": "hu_hu",
    "ind": "id_id",
    "ibo": "ig_ng",
    "isl": "is_is",
    "jpn": "ja_jp",
    "jav": "jv_id",
    "kaz": "kk_kz",
    "kan": "kn_in",
    "kir": "ky_kg",
    "kor": "ko_kr",
    "lit": "lt_lt",
    "mkd": "mk_mk",
    "mlt": "mt_mt",
    "mya": "my_mm",
    "nld": "nl_nl",
    "pan": "pa_in",
    "pol": "pl_pl",
    "ron": "ro_ro",
    "rus": "ru_ru",
    "snd": "sd_in",
    "slk": "sk_sk",
    "srp": "sr_rs",
    "swh": "sw_ke",
    "tam": "ta_in",
    "tel": "te_in",
    "tha": "th_th",
    "tur": "tr_tr",
    "ukr": "uk_ua",
    "urd": "ur_pk",
    "uzn": "uz_uz",
    "vie": "vi_vn",
    "yor": "yo_ng",
    "zul": "zu_za",
}
88
+
89
+
90
def _check_lang_code_mapping(lang: str) -> None:
    """Validate that ``lang`` is a key of ``UNITY_TO_FLEURS_LANG_MAPPING``.

    Args:
        lang: M4T language code to look up.

    Raises:
        ValueError: If ``lang`` has no entry in the mapping.
    """
    if lang in UNITY_TO_FLEURS_LANG_MAPPING:
        return
    message = (
        f"No language code mapping for {lang}(M4T)->??(FLEURs). "
        "Please expand `UNITY_TO_FLEURS_LANG_MAPPING`"
    )
    raise ValueError(message)
96
+
97
+
98
class UnitSpeechTokenizer(SpeechTokenizer):
    """Speech tokenizer that delegates encoding to a ``UnitExtractor``."""

    # Arguments handed to ``UnitExtractor`` / ``UnitExtractor.predict``.
    MODEL_NAME = "xlsr2_1b_v2"
    KMEANS_MODEL_URI = "https://dl.fbaipublicfiles.com/seamlessM4T/models/unit_extraction/kmeans_10k.npy"
    OUTPUT_LAYER_IDX = 34

    def __init__(self, device: torch.device):
        """Create the underlying unit extractor on ``device``."""
        super().__init__()
        self.device = device
        extractor_kwargs = {
            "model_name_or_card": self.MODEL_NAME,
            "kmeans_uri": self.KMEANS_MODEL_URI,
            "device": self.device,
        }
        self.unit_extractor = UnitExtractor(**extractor_kwargs)

    def encode(self, wav: torch.Tensor, sample_rate: int) -> torch.Tensor:
        """Move ``wav`` to the tokenizer device and return the extractor's prediction."""
        wav_on_device = wav.to(self.device)
        return self.unit_extractor.predict(
            wav_on_device,
            out_layer_idx=self.OUTPUT_LAYER_IDX,
            sample_rate=sample_rate,
        )
118
+
119
+
120
def download_fleurs_dataset(
    source_lang: str,
    target_lang: str,
    split: str,
    save_directory: str,
) -> str:
    """Write a JSON-lines manifest for one FLEURS split and return its path.

    Each sample produced by ``Speech2SpeechFleursDatasetBuilder`` is serialized
    with ``dataclasses.asdict`` + ``json.dumps`` as one line of
    ``<save_directory>/<split>_manifest.json``.

    Args:
        source_lang: M4T language code of the source side.
        target_lang: M4T language code of the target side.
        split: Dataset split name, forwarded to the dataset builder.
        save_directory: Directory used as the dataset cache dir and as the
            location of the manifest file.

    Returns:
        Path of the written manifest file.

    Raises:
        ValueError: If either language code is missing from
            ``UNITY_TO_FLEURS_LANG_MAPPING``.
    """
    _check_lang_code_mapping(source_lang)
    _check_lang_code_mapping(target_lang)
    # Use the first CUDA device when one is visible, otherwise fall back to CPU.
    device = (
        torch.device("cuda:0") if torch.cuda.device_count() > 0 else torch.device("cpu")
    )
    tokenizer = UnitSpeechTokenizer(device=device)
    dataset_iterator = Speech2SpeechFleursDatasetBuilder(
        source_lang=UNITY_TO_FLEURS_LANG_MAPPING[source_lang],
        target_lang=UNITY_TO_FLEURS_LANG_MAPPING[target_lang],
        dataset_cache_dir=save_directory,
        speech_tokenizer=tokenizer,
        skip_source_audio=True,  # don't extract units from source audio
        skip_target_audio=False,
        split=split,
    )
    manifest_path: str = os.path.join(save_directory, f"{split}_manifest.json")
    # Pre-initialized so the summary log below is valid even when the split
    # yields no samples (the loop target would otherwise never be bound).
    num_samples = 0
    with open(manifest_path, "w") as fp_out:
        for num_samples, sample in enumerate(dataset_iterator, start=1):
            # correction: the dataset builder returns FLEURS lang codes,
            # the manifest must carry the M4T ones
            sample.source.lang = source_lang
            sample.target.lang = target_lang
            sample.target.waveform = None  # already extracted units
            fp_out.write(json.dumps(dataclasses.asdict(sample)) + "\n")
    logger.info(f"Saved {num_samples} samples for split={split} to {manifest_path}")
    return manifest_path
151
+
152
+
153
def init_parser() -> argparse.ArgumentParser:
    """Build the command-line parser for this script.

    All four options (``--source_lang``, ``--target_lang``, ``--split``,
    ``--save_dir``) are required; ``--save_dir`` is parsed into a
    ``pathlib.Path``, the others stay strings.

    Returns:
        The configured ``argparse.ArgumentParser``.
    """
    parser = argparse.ArgumentParser(
        description=(
            # Adjacent literals are concatenated verbatim, so each fragment
            # must carry its own separating space.
            "Helper script to download training/evaluation dataset (FLEURS), "
            "extract units from target audio and save the dataset as a manifest "
            "consumable by `finetune.py`."
        )
    )
    parser.add_argument(
        "--source_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset SOURCE language",
    )
    parser.add_argument(
        "--target_lang",
        type=str,
        required=True,
        help="M4T langcode of the dataset TARGET language",
    )
    parser.add_argument(
        "--split",
        type=str,
        required=True,
        help="Dataset split/shard to download (`train`, `validation`, `test`)",
    )
    parser.add_argument(
        "--save_dir",
        type=Path,
        required=True,
        help="Directory where the datasets will be stored with HuggingFace datasets cache files",
    )
    return parser
186
+
187
+
188
def main() -> None:
    """Parse the command line, build the dataset manifest and log its path."""
    args = init_parser().parse_args()
    manifest_path = download_fleurs_dataset(
        source_lang=args.source_lang,
        target_lang=args.target_lang,
        split=args.split,
        # `--save_dir` is parsed as a pathlib.Path, while
        # `download_fleurs_dataset` declares `save_directory: str`.
        save_directory=str(args.save_dir),
    )
    logger.info(f"Manifest saved to: {manifest_path}")


if __name__ == "__main__":
    main()