jxtan commited on
Commit
76e433a
1 Parent(s): c828814

Initial Commit to Test Runpod

Browse files
Dockerfile ADDED
@@ -0,0 +1,16 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvidia/cuda:12.1.1-cudnn8-devel-ubuntu22.04
2
+ ENV DEBIAN_FRONTEND=noninteractive
3
+
4
+ RUN useradd -m -u 1000 user
5
+ USER user
6
+ ENV HOME=/home/user \
7
+ PATH=/home/user/.local/bin:${PATH}
8
+ WORKDIR ${HOME}/app
9
+
10
+ COPY --chown=1000 . ${HOME}/app
11
+ RUN pip install -r ${HOME}/app/requirements.txt && \
12
+ pip install fairseq2 --pre --extra-index-url https://fair.pkg.atmeta.com/fairseq2/pt2.1.0/cu121 && \
13
+ pip install ${HOME}/app/whl/seamless_communication-1.0.0-py3-none-any.whl
14
+ # This will cache the model into the docker image
15
+ RUN python -u translator.py
16
+ CMD ["python", "main.py"]
lang_list.py ADDED
@@ -0,0 +1,255 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Language dict
2
+ language_code_to_name = {
3
+ "afr": "Afrikaans",
4
+ "amh": "Amharic",
5
+ "arb": "Modern Standard Arabic",
6
+ "ary": "Moroccan Arabic",
7
+ "arz": "Egyptian Arabic",
8
+ "asm": "Assamese",
9
+ "ast": "Asturian",
10
+ "azj": "North Azerbaijani",
11
+ "bel": "Belarusian",
12
+ "ben": "Bengali",
13
+ "bos": "Bosnian",
14
+ "bul": "Bulgarian",
15
+ "cat": "Catalan",
16
+ "ceb": "Cebuano",
17
+ "ces": "Czech",
18
+ "ckb": "Central Kurdish",
19
+ "cmn": "Mandarin Chinese",
20
+ "cym": "Welsh",
21
+ "dan": "Danish",
22
+ "deu": "German",
23
+ "ell": "Greek",
24
+ "eng": "English",
25
+ "est": "Estonian",
26
+ "eus": "Basque",
27
+ "fin": "Finnish",
28
+ "fra": "French",
29
+ "gaz": "West Central Oromo",
30
+ "gle": "Irish",
31
+ "glg": "Galician",
32
+ "guj": "Gujarati",
33
+ "heb": "Hebrew",
34
+ "hin": "Hindi",
35
+ "hrv": "Croatian",
36
+ "hun": "Hungarian",
37
+ "hye": "Armenian",
38
+ "ibo": "Igbo",
39
+ "ind": "Indonesian",
40
+ "isl": "Icelandic",
41
+ "ita": "Italian",
42
+ "jav": "Javanese",
43
+ "jpn": "Japanese",
44
+ "kam": "Kamba",
45
+ "kan": "Kannada",
46
+ "kat": "Georgian",
47
+ "kaz": "Kazakh",
48
+ "kea": "Kabuverdianu",
49
+ "khk": "Halh Mongolian",
50
+ "khm": "Khmer",
51
+ "kir": "Kyrgyz",
52
+ "kor": "Korean",
53
+ "lao": "Lao",
54
+ "lit": "Lithuanian",
55
+ "ltz": "Luxembourgish",
56
+ "lug": "Ganda",
57
+ "luo": "Luo",
58
+ "lvs": "Standard Latvian",
59
+ "mai": "Maithili",
60
+ "mal": "Malayalam",
61
+ "mar": "Marathi",
62
+ "mkd": "Macedonian",
63
+ "mlt": "Maltese",
64
+ "mni": "Meitei",
65
+ "mya": "Burmese",
66
+ "nld": "Dutch",
67
+ "nno": "Norwegian Nynorsk",
68
+ "nob": "Norwegian Bokm\u00e5l",
69
+ "npi": "Nepali",
70
+ "nya": "Nyanja",
71
+ "oci": "Occitan",
72
+ "ory": "Odia",
73
+ "pan": "Punjabi",
74
+ "pbt": "Southern Pashto",
75
+ "pes": "Western Persian",
76
+ "pol": "Polish",
77
+ "por": "Portuguese",
78
+ "ron": "Romanian",
79
+ "rus": "Russian",
80
+ "slk": "Slovak",
81
+ "slv": "Slovenian",
82
+ "sna": "Shona",
83
+ "snd": "Sindhi",
84
+ "som": "Somali",
85
+ "spa": "Spanish",
86
+ "srp": "Serbian",
87
+ "swe": "Swedish",
88
+ "swh": "Swahili",
89
+ "tam": "Tamil",
90
+ "tel": "Telugu",
91
+ "tgk": "Tajik",
92
+ "tgl": "Tagalog",
93
+ "tha": "Thai",
94
+ "tur": "Turkish",
95
+ "ukr": "Ukrainian",
96
+ "urd": "Urdu",
97
+ "uzn": "Northern Uzbek",
98
+ "vie": "Vietnamese",
99
+ "xho": "Xhosa",
100
+ "yor": "Yoruba",
101
+ "yue": "Cantonese",
102
+ "zlm": "Colloquial Malay",
103
+ "zsm": "Standard Malay",
104
+ "zul": "Zulu",
105
+ }
106
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
107
+
108
+ # Source langs: S2ST / S2TT / ASR don't need source lang
109
+ # T2TT / T2ST use this
110
+ text_source_language_codes = [
111
+ "afr",
112
+ "amh",
113
+ "arb",
114
+ "ary",
115
+ "arz",
116
+ "asm",
117
+ "azj",
118
+ "bel",
119
+ "ben",
120
+ "bos",
121
+ "bul",
122
+ "cat",
123
+ "ceb",
124
+ "ces",
125
+ "ckb",
126
+ "cmn",
127
+ "cym",
128
+ "dan",
129
+ "deu",
130
+ "ell",
131
+ "eng",
132
+ "est",
133
+ "eus",
134
+ "fin",
135
+ "fra",
136
+ "gaz",
137
+ "gle",
138
+ "glg",
139
+ "guj",
140
+ "heb",
141
+ "hin",
142
+ "hrv",
143
+ "hun",
144
+ "hye",
145
+ "ibo",
146
+ "ind",
147
+ "isl",
148
+ "ita",
149
+ "jav",
150
+ "jpn",
151
+ "kan",
152
+ "kat",
153
+ "kaz",
154
+ "khk",
155
+ "khm",
156
+ "kir",
157
+ "kor",
158
+ "lao",
159
+ "lit",
160
+ "lug",
161
+ "luo",
162
+ "lvs",
163
+ "mai",
164
+ "mal",
165
+ "mar",
166
+ "mkd",
167
+ "mlt",
168
+ "mni",
169
+ "mya",
170
+ "nld",
171
+ "nno",
172
+ "nob",
173
+ "npi",
174
+ "nya",
175
+ "ory",
176
+ "pan",
177
+ "pbt",
178
+ "pes",
179
+ "pol",
180
+ "por",
181
+ "ron",
182
+ "rus",
183
+ "slk",
184
+ "slv",
185
+ "sna",
186
+ "snd",
187
+ "som",
188
+ "spa",
189
+ "srp",
190
+ "swe",
191
+ "swh",
192
+ "tam",
193
+ "tel",
194
+ "tgk",
195
+ "tgl",
196
+ "tha",
197
+ "tur",
198
+ "ukr",
199
+ "urd",
200
+ "uzn",
201
+ "vie",
202
+ "yor",
203
+ "yue",
204
+ "zsm",
205
+ "zul",
206
+ ]
207
+ TEXT_SOURCE_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in text_source_language_codes])
208
+
209
+ # Target langs:
210
+ # S2ST / T2ST
211
+ s2st_target_language_codes = [
212
+ "eng",
213
+ "arb",
214
+ "ben",
215
+ "cat",
216
+ "ces",
217
+ "cmn",
218
+ "cym",
219
+ "dan",
220
+ "deu",
221
+ "est",
222
+ "fin",
223
+ "fra",
224
+ "hin",
225
+ "ind",
226
+ "ita",
227
+ "jpn",
228
+ "kor",
229
+ "mlt",
230
+ "nld",
231
+ "pes",
232
+ "pol",
233
+ "por",
234
+ "ron",
235
+ "rus",
236
+ "slk",
237
+ "spa",
238
+ "swe",
239
+ "swh",
240
+ "tel",
241
+ "tgl",
242
+ "tha",
243
+ "tur",
244
+ "ukr",
245
+ "urd",
246
+ "uzn",
247
+ "vie",
248
+ ]
249
+ S2ST_TARGET_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in s2st_target_language_codes])
250
+ T2ST_TARGET_LANGUAGE_NAMES = S2ST_TARGET_LANGUAGE_NAMES
251
+
252
+ # S2TT / T2TT / ASR
253
+ S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
254
+ T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
255
+ ASR_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
main.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # from translator import translator
2
+ from lang_list import LANGUAGE_NAME_TO_CODE
3
+ import runpod
4
+
5
+ def run_t2tt(input_text: str, source_language: str, target_language: str) -> str:
6
+ source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
7
+ target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
8
+ # out_texts, _ = translator.predict(
9
+ # input=input_text,
10
+ # task_str="T2TT",
11
+ # src_lang=source_language_code,
12
+ # tgt_lang=target_language_code,
13
+ # )
14
+ # return str(out_texts[0])
15
+ import json
16
+ return json.dumps({"input_text": input_text, "src_code": source_language_code, "tgt_code": target_language_code})
17
+
18
+ def runpod_handler(job):
19
+ job_input = job['input']
20
+ input_text = job_input["input_text"]
21
+ source_language = job_input["source_language"]
22
+ target_language = job_input["target_language"]
23
+ return run_t2tt(input_text, source_language, target_language)
24
+
25
+ runpod.serverless.start({"handler": runpod_handler})
requirements.txt ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ gradio==4.9.0
2
+ omegaconf==2.3.0
3
+ torch==2.1.0
4
+ torchaudio==2.1.0
5
+ runpod
test_input.json ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ {
2
+ "input": {
3
+ "input_text": "How are you doing today?",
4
+ "source_language": "English",
5
+ "target_language": "Mandarin Chinese"
6
+ }
7
+ }
translator.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import pathlib
3
+ import torch
4
+ from fairseq2.assets import InProcAssetMetadataProvider, asset_store
5
+ from huggingface_hub import snapshot_download
6
+ from seamless_communication.inference import Translator
7
+
8
+ CHECKPOINTS_PATH = pathlib.Path(os.getenv("CHECKPOINTS_PATH", "/home/user/app/models"))
9
+ if not CHECKPOINTS_PATH.exists():
10
+ snapshot_download(repo_id="facebook/seamless-m4t-v2-large", repo_type="model", local_dir=CHECKPOINTS_PATH)
11
+ asset_store.env_resolvers.clear()
12
+ asset_store.env_resolvers.append(lambda: "demo")
13
+ demo_metadata = [
14
+ {
15
+ "name": "seamlessM4T_v2_large@demo",
16
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/seamlessM4T_v2_large.pt",
17
+ "char_tokenizer": f"file://{CHECKPOINTS_PATH}/spm_char_lang38_tc.model",
18
+ },
19
+ {
20
+ "name": "vocoder_v2@demo",
21
+ "checkpoint": f"file://{CHECKPOINTS_PATH}/vocoder_v2.pt",
22
+ },
23
+ ]
24
+ asset_store.metadata_providers.append(InProcAssetMetadataProvider(demo_metadata))
25
+
26
+ if torch.cuda.is_available():
27
+ device = torch.device("cuda:0")
28
+ dtype = torch.float16
29
+ else:
30
+ device = torch.device("cpu")
31
+ dtype = torch.float32
32
+
33
+ translator = Translator(
34
+ model_name_or_card="seamlessM4T_v2_large",
35
+ vocoder_name_or_card="vocoder_v2",
36
+ device=device,
37
+ dtype=dtype,
38
+ apply_mintox=True,
39
+ )
whl/seamless_communication-1.0.0-py3-none-any.whl ADDED
Binary file (202 kB). View file