Update language list

#1
by hysts HF staff - opened
Files changed (3) hide show
  1. app.py +48 -16
  2. lang_list.py +254 -0
  3. mlg_config.json +0 -186
app.py CHANGED
@@ -1,4 +1,3 @@
1
- import json
2
  import os
3
 
4
  import gradio as gr
@@ -7,11 +6,15 @@ import torch
7
  import torchaudio
8
  from seamless_communication.models.inference.translator import Translator
9
 
10
- DESCRIPTION = "# SeamlessM4T"
 
 
 
 
 
 
11
 
12
- with open("./mlg_config.json", "r") as f:
13
- lang_idx_map = json.loads(f.read())
14
- LANGUAGES = lang_idx_map["multilingual"].keys()
15
 
16
  TASK_NAMES = [
17
  "S2ST (Speech to Speech translation)",
@@ -24,6 +27,8 @@ TASK_NAMES = [
24
  AUDIO_SAMPLE_RATE = 16000.0
25
  MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
26
 
 
 
27
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
28
  translator = Translator(
29
  model_name_or_card="multitask_unity_large",
@@ -43,6 +48,9 @@ def predict(
43
  target_language: str,
44
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
45
  task_name = task_name.split()[0]
 
 
 
46
  if task_name in ["S2ST", "S2TT", "ASR"]:
47
  if audio_source == "microphone":
48
  input_data = input_audio_mic
@@ -61,8 +69,8 @@ def predict(
61
  text_out, wav, sr = translator.predict(
62
  input=input_data,
63
  task_str=task_name,
64
- tgt_lang=target_language,
65
- src_lang=source_language,
66
  )
67
  if task_name in ["S2ST", "T2ST"]:
68
  return (sr, wav.cpu().detach().numpy()), text_out
@@ -80,26 +88,50 @@ def update_audio_ui(audio_source: str) -> tuple[dict, dict]:
80
 
81
  def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
82
  task_name = task_name.split()[0]
83
- if task_name in ["S2ST", "S2TT"]:
84
  return (
85
  gr.update(visible=True), # audio_box
86
  gr.update(visible=False), # input_text
87
  gr.update(visible=False), # source_language
88
- gr.update(visible=True), # target_language
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
89
  )
90
- elif task_name in ["T2ST", "T2TT"]:
91
  return (
92
  gr.update(visible=False), # audio_box
93
  gr.update(visible=True), # input_text
94
  gr.update(visible=True), # source_language
95
- gr.update(visible=True), # target_language
 
 
96
  )
97
  elif task_name == "ASR":
98
  return (
99
  gr.update(visible=True), # audio_box
100
  gr.update(visible=False), # input_text
101
  gr.update(visible=False), # source_language
102
- gr.update(visible=True), # target_language
 
 
103
  )
104
  else:
105
  raise ValueError(f"Unknown task: {task_name}")
@@ -137,14 +169,14 @@ with gr.Blocks(css="style.css") as demo:
137
  with gr.Row():
138
  source_language = gr.Dropdown(
139
  label="Source language",
140
- choices=LANGUAGES,
141
- value="eng",
142
  visible=False,
143
  )
144
  target_language = gr.Dropdown(
145
  label="Target language",
146
- choices=LANGUAGES,
147
- value="fra",
148
  )
149
  with gr.Row() as audio_box:
150
  audio_source = gr.Radio(
 
 
1
  import os
2
 
3
  import gradio as gr
 
6
  import torchaudio
7
  from seamless_communication.models.inference.translator import Translator
8
 
9
+ from lang_list import (
10
+ LANGUAGE_NAME_TO_CODE,
11
+ S2ST_TARGET_LANGUAGE_NAMES,
12
+ S2TT_TARGET_LANGUAGE_NAMES,
13
+ T2TT_TARGET_LANGUAGE_NAMES,
14
+ TEXT_SOURCE_LANGUAGE_NAMES,
15
+ )
16
 
17
+ DESCRIPTION = "# SeamlessM4T"
 
 
18
 
19
  TASK_NAMES = [
20
  "S2ST (Speech to Speech translation)",
 
27
  AUDIO_SAMPLE_RATE = 16000.0
28
  MAX_INPUT_AUDIO_LENGTH = 60 # in seconds
29
 
30
+ DEFAULT_TARGET_LANGUAGE = "French"
31
+
32
  device = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
33
  translator = Translator(
34
  model_name_or_card="multitask_unity_large",
 
48
  target_language: str,
49
  ) -> tuple[tuple[int, np.ndarray] | None, str]:
50
  task_name = task_name.split()[0]
51
+ source_language_code = LANGUAGE_NAME_TO_CODE[source_language]
52
+ target_language_code = LANGUAGE_NAME_TO_CODE[target_language]
53
+
54
  if task_name in ["S2ST", "S2TT", "ASR"]:
55
  if audio_source == "microphone":
56
  input_data = input_audio_mic
 
69
  text_out, wav, sr = translator.predict(
70
  input=input_data,
71
  task_str=task_name,
72
+ tgt_lang=target_language_code,
73
+ src_lang=source_language_code,
74
  )
75
  if task_name in ["S2ST", "T2ST"]:
76
  return (sr, wav.cpu().detach().numpy()), text_out
 
88
 
89
  def update_input_ui(task_name: str) -> tuple[dict, dict, dict, dict]:
90
  task_name = task_name.split()[0]
91
+ if task_name == "S2ST":
92
  return (
93
  gr.update(visible=True), # audio_box
94
  gr.update(visible=False), # input_text
95
  gr.update(visible=False), # source_language
96
+ gr.update(
97
+ visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
98
+ ), # target_language
99
+ )
100
+ elif task_name == "S2TT":
101
+ return (
102
+ gr.update(visible=True), # audio_box
103
+ gr.update(visible=False), # input_text
104
+ gr.update(visible=False), # source_language
105
+ gr.update(
106
+ visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
107
+ ), # target_language
108
+ )
109
+ elif task_name == "T2ST":
110
+ return (
111
+ gr.update(visible=False), # audio_box
112
+ gr.update(visible=True), # input_text
113
+ gr.update(visible=True), # source_language
114
+ gr.update(
115
+ visible=True, choices=S2ST_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
116
+ ), # target_language
117
  )
118
+ elif task_name == "T2TT":
119
  return (
120
  gr.update(visible=False), # audio_box
121
  gr.update(visible=True), # input_text
122
  gr.update(visible=True), # source_language
123
+ gr.update(
124
+ visible=True, choices=T2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
125
+ ), # target_language
126
  )
127
  elif task_name == "ASR":
128
  return (
129
  gr.update(visible=True), # audio_box
130
  gr.update(visible=False), # input_text
131
  gr.update(visible=False), # source_language
132
+ gr.update(
133
+ visible=True, choices=S2TT_TARGET_LANGUAGE_NAMES, value=DEFAULT_TARGET_LANGUAGE
134
+ ), # target_language
135
  )
136
  else:
137
  raise ValueError(f"Unknown task: {task_name}")
 
169
  with gr.Row():
170
  source_language = gr.Dropdown(
171
  label="Source language",
172
+ choices=TEXT_SOURCE_LANGUAGE_NAMES,
173
+ value="English",
174
  visible=False,
175
  )
176
  target_language = gr.Dropdown(
177
  label="Target language",
178
+ choices=S2ST_TARGET_LANGUAGE_NAMES,
179
+ value=DEFAULT_TARGET_LANGUAGE,
180
  )
181
  with gr.Row() as audio_box:
182
  audio_source = gr.Radio(
lang_list.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Language dict
2
+ language_code_to_name = {
3
+ "afr": "Afrikaans",
4
+ "amh": "Amharic",
5
+ "arb": "Modern Standard Arabic",
6
+ "ary": "Moroccan Arabic",
7
+ "arz": "Egyptian Arabic",
8
+ "asm": "Assamese",
9
+ "ast": "Asturian",
10
+ "azj": "North Azerbaijani",
11
+ "bel": "Belarusian",
12
+ "ben": "Bengali",
13
+ "bos": "Bosnian",
14
+ "bul": "Bulgarian",
15
+ "cat": "Catalan",
16
+ "ceb": "Cebuano",
17
+ "ces": "Czech",
18
+ "ckb": "Central Kurdish",
19
+ "cmn": "Mandarin Chinese",
20
+ "cym": "Welsh",
21
+ "dan": "Danish",
22
+ "deu": "German",
23
+ "ell": "Greek",
24
+ "eng": "English",
25
+ "est": "Estonian",
26
+ "eus": "Basque",
27
+ "fin": "Finnish",
28
+ "fra": "French",
29
+ "gaz": "West Central Oromo",
30
+ "gle": "Irish",
31
+ "glg": "Galician",
32
+ "guj": "Gujarati",
33
+ "heb": "Hebrew",
34
+ "hin": "Hindi",
35
+ "hrv": "Croatian",
36
+ "hun": "Hungarian",
37
+ "hye": "Armenian",
38
+ "ibo": "Igbo",
39
+ "ind": "Indonesian",
40
+ "isl": "Icelandic",
41
+ "ita": "Italian",
42
+ "jav": "Javanese",
43
+ "jpn": "Japanese",
44
+ "kam": "Kamba",
45
+ "kan": "Kannada",
46
+ "kat": "Georgian",
47
+ "kaz": "Kazakh",
48
+ "kea": "Kabuverdianu",
49
+ "khk": "Halh Mongolian",
50
+ "khm": "Khmer",
51
+ "kir": "Kyrgyz",
52
+ "kor": "Korean",
53
+ "lao": "Lao",
54
+ "lit": "Lithuanian",
55
+ "ltz": "Luxembourgish",
56
+ "lug": "Ganda",
57
+ "luo": "Luo",
58
+ "lvs": "Standard Latvian",
59
+ "mai": "Maithili",
60
+ "mal": "Malayalam",
61
+ "mar": "Marathi",
62
+ "mkd": "Macedonian",
63
+ "mlt": "Maltese",
64
+ "mni": "Meitei",
65
+ "mya": "Burmese",
66
+ "nld": "Dutch",
67
+ "nno": "Norwegian Nynorsk",
68
+ "nob": "Norwegian Bokm\u00e5l",
69
+ "npi": "Nepali",
70
+ "nya": "Nyanja",
71
+ "oci": "Occitan",
72
+ "ory": "Odia",
73
+ "pan": "Punjabi",
74
+ "pbt": "Southern Pashto",
75
+ "pes": "Western Persian",
76
+ "pol": "Polish",
77
+ "por": "Portuguese",
78
+ "ron": "Romanian",
79
+ "rus": "Russian",
80
+ "slk": "Slovak",
81
+ "slv": "Slovenian",
82
+ "sna": "Shona",
83
+ "snd": "Sindhi",
84
+ "som": "Somali",
85
+ "spa": "Spanish",
86
+ "srp": "Serbian",
87
+ "swe": "Swedish",
88
+ "swh": "Swahili",
89
+ "tam": "Tamil",
90
+ "tel": "Telugu",
91
+ "tgk": "Tajik",
92
+ "tgl": "Tagalog",
93
+ "tha": "Thai",
94
+ "tur": "Turkish",
95
+ "ukr": "Ukrainian",
96
+ "urd": "Urdu",
97
+ "uzn": "Northern Uzbek",
98
+ "vie": "Vietnamese",
99
+ "xho": "Xhosa",
100
+ "yor": "Yoruba",
101
+ "yue": "Cantonese",
102
+ "zlm": "Colloquial Malay",
103
+ "zsm": "Standard Malay",
104
+ "zul": "Zulu",
105
+ }
106
+ LANGUAGE_NAME_TO_CODE = {v: k for k, v in language_code_to_name.items()}
107
+
108
+ # Source langs: S2ST / S2TT / ASR don't need source lang
109
+ # T2TT / T2ST use this
110
+ text_source_language_codes = [
111
+ "afr",
112
+ "amh",
113
+ "arb",
114
+ "ary",
115
+ "arz",
116
+ "asm",
117
+ "azj",
118
+ "bel",
119
+ "ben",
120
+ "bos",
121
+ "bul",
122
+ "cat",
123
+ "ceb",
124
+ "ces",
125
+ "ckb",
126
+ "cmn",
127
+ "cym",
128
+ "dan",
129
+ "deu",
130
+ "ell",
131
+ "eng",
132
+ "est",
133
+ "eus",
134
+ "fin",
135
+ "fra",
136
+ "gaz",
137
+ "gle",
138
+ "glg",
139
+ "guj",
140
+ "heb",
141
+ "hin",
142
+ "hrv",
143
+ "hun",
144
+ "hye",
145
+ "ibo",
146
+ "ind",
147
+ "isl",
148
+ "ita",
149
+ "jav",
150
+ "jpn",
151
+ "kan",
152
+ "kat",
153
+ "kaz",
154
+ "khk",
155
+ "khm",
156
+ "kir",
157
+ "kor",
158
+ "lao",
159
+ "lit",
160
+ "lug",
161
+ "luo",
162
+ "lvs",
163
+ "mai",
164
+ "mal",
165
+ "mar",
166
+ "mkd",
167
+ "mlt",
168
+ "mni",
169
+ "mya",
170
+ "nld",
171
+ "nno",
172
+ "nob",
173
+ "npi",
174
+ "nya",
175
+ "ory",
176
+ "pan",
177
+ "pbt",
178
+ "pes",
179
+ "pol",
180
+ "por",
181
+ "ron",
182
+ "rus",
183
+ "slk",
184
+ "slv",
185
+ "sna",
186
+ "snd",
187
+ "som",
188
+ "spa",
189
+ "srp",
190
+ "swe",
191
+ "swh",
192
+ "tam",
193
+ "tel",
194
+ "tgk",
195
+ "tgl",
196
+ "tha",
197
+ "tur",
198
+ "ukr",
199
+ "urd",
200
+ "uzn",
201
+ "vie",
202
+ "yor",
203
+ "yue",
204
+ "zsm",
205
+ "zul",
206
+ ]
207
+ TEXT_SOURCE_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in text_source_language_codes])
208
+
209
+ # Target langs:
210
+ # S2ST / T2ST
211
+ s2st_target_language_codes = [
212
+ "eng",
213
+ "arb",
214
+ "ben",
215
+ "cat",
216
+ "ces",
217
+ "cmn",
218
+ "cym",
219
+ "dan",
220
+ "deu",
221
+ "est",
222
+ "fin",
223
+ "fra",
224
+ "hin",
225
+ "ind",
226
+ "ita",
227
+ "jpn",
228
+ "kor",
229
+ "mlt",
230
+ "nld",
231
+ "pes",
232
+ "pol",
233
+ "por",
234
+ "ron",
235
+ "rus",
236
+ "slk",
237
+ "spa",
238
+ "swe",
239
+ "swh",
240
+ "tel",
241
+ "tgl",
242
+ "tha",
243
+ "tur",
244
+ "ukr",
245
+ "urd",
246
+ "uzn",
247
+ "vie",
248
+ ]
249
+ S2ST_TARGET_LANGUAGE_NAMES = sorted([language_code_to_name[code] for code in s2st_target_language_codes])
250
+
251
+ # S2TT / ASR
252
+ S2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
253
+ # T2TT
254
+ T2TT_TARGET_LANGUAGE_NAMES = TEXT_SOURCE_LANGUAGE_NAMES
mlg_config.json DELETED
@@ -1,186 +0,0 @@
1
- {
2
- "multilingual": {
3
- "arb": 0,
4
- "ben": 1,
5
- "cat": 2,
6
- "ces": 3,
7
- "cmn": 4,
8
- "cym": 5,
9
- "dan": 6,
10
- "deu": 7,
11
- "eng": 8,
12
- "est": 9,
13
- "fin": 10,
14
- "fra": 11,
15
- "hin": 12,
16
- "ind": 13,
17
- "ita": 14,
18
- "jpn": 15,
19
- "kor": 16,
20
- "mlt": 17,
21
- "nld": 18,
22
- "pes": 19,
23
- "pol": 20,
24
- "por": 21,
25
- "ron": 22,
26
- "rus": 23,
27
- "slk": 24,
28
- "spa": 25,
29
- "swe": 26,
30
- "swh": 27,
31
- "tel": 28,
32
- "tgl": 29,
33
- "tha": 30,
34
- "tur": 31,
35
- "ukr": 32,
36
- "urd": 33,
37
- "uzn": 34,
38
- "vie": 35
39
- },
40
- "multispkr": {
41
- "arb": [
42
- 0
43
- ],
44
- "ben": [
45
- 2,
46
- 1
47
- ],
48
- "cat": [
49
- 3
50
- ],
51
- "ces": [
52
- 4
53
- ],
54
- "cmn": [
55
- 5
56
- ],
57
- "cym": [
58
- 6
59
- ],
60
- "dan": [
61
- 7,
62
- 8
63
- ],
64
- "deu": [
65
- 9
66
- ],
67
- "eng": [
68
- 10
69
- ],
70
- "est": [
71
- 11,
72
- 12,
73
- 13
74
- ],
75
- "fin": [
76
- 14
77
- ],
78
- "fra": [
79
- 15
80
- ],
81
- "hin": [
82
- 16
83
- ],
84
- "ind": [
85
- 17,
86
- 24,
87
- 18,
88
- 20,
89
- 19,
90
- 21,
91
- 23,
92
- 27,
93
- 26,
94
- 22,
95
- 25
96
- ],
97
- "ita": [
98
- 29,
99
- 28
100
- ],
101
- "jpn": [
102
- 30
103
- ],
104
- "kor": [
105
- 31
106
- ],
107
- "mlt": [
108
- 32,
109
- 33,
110
- 34
111
- ],
112
- "nld": [
113
- 35
114
- ],
115
- "pes": [
116
- 36
117
- ],
118
- "pol": [
119
- 37
120
- ],
121
- "por": [
122
- 38
123
- ],
124
- "ron": [
125
- 39
126
- ],
127
- "rus": [
128
- 40
129
- ],
130
- "slk": [
131
- 41
132
- ],
133
- "spa": [
134
- 42
135
- ],
136
- "swe": [
137
- 43,
138
- 45,
139
- 44
140
- ],
141
- "swh": [
142
- 46,
143
- 48,
144
- 47
145
- ],
146
- "tel": [
147
- 49
148
- ],
149
- "tgl": [
150
- 50
151
- ],
152
- "tha": [
153
- 51,
154
- 54,
155
- 55,
156
- 52,
157
- 53
158
- ],
159
- "tur": [
160
- 58,
161
- 57,
162
- 56
163
- ],
164
- "ukr": [
165
- 59
166
- ],
167
- "urd": [
168
- 60,
169
- 61,
170
- 62
171
- ],
172
- "uzn": [
173
- 63,
174
- 64,
175
- 65
176
- ],
177
- "vie": [
178
- 66,
179
- 67,
180
- 70,
181
- 71,
182
- 68,
183
- 69
184
- ]
185
- }
186
- }