Maximofn commited on
Commit
46736da
1 Parent(s): 7d87fea

Script to translate a concatenated transcription file, plus the languages list

Browse files
Files changed (2) hide show
  1. lang_list.py +175 -0
  2. translate_transcriptions.py +84 -0
lang_list.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Language names (written in each language's own script) mapped to the language codes used by the translator
2
+ LANGUAGE_NAME_TO_CODE = {
3
+ "العربية": "ar_AR",
4
+ "Čeština": "cs_CZ",
5
+ "Deutsch": "de_DE",
6
+ "English": "en_XX",
7
+ "Español": "es_XX",
8
+ "Eesti": "et_EE",
9
+ "Suomi": "fi_FI",
10
+ "Français": "fr_XX",
11
+ "ગુજરાતી": "gu_IN",
12
+ "हिन्दी": "hi_IN",
13
+ "Italiano": "it_IT",
14
+ "日本語": "ja_XX",
15
+ "Қазақ": "kk_KZ",
16
+ "한국어": "ko_KR",
17
+ "Lietuvių": "lt_LT",
18
+ "Latviešu": "lv_LV",
19
+ "ဗမာ": "my_MM",
20
+ "नेपाली": "ne_NP",
21
+ "Nederlands": "nl_XX",
22
+ "Română": "ro_RO",
23
+ "Русский": "ru_RU",
24
+ "සිංහල": "si_LK",
25
+ "Türkçe": "tr_TR",
26
+ "Tiếng Việt": "vi_VN",
27
+ "中文": "zh_CN",
28
+ "Afrikaans": "af_ZA",
29
+ "Azərbaycan": "az_AZ",
30
+ "বাংলা": "bn_IN",
31
+ "فارسی": "fa_IR",
32
+ "עברית": "he_IL",
33
+ "Hrvatski": "hr_HR",
34
+ "Indonesia": "id_ID",
35
+ "ქართული": "ka_GE",
36
+ "ខ្មែរ": "km_KH",
37
+ "Македонски": "mk_MK",
38
+ "മലയാളം": "ml_IN",
39
+ "Монгол": "mn_MN",
40
+ "मराठी": "mr_IN",
41
+ "Polski": "pl_PL",
42
+ "پښتو": "ps_AF",
43
+ "Português": "pt_XX",
44
+ "Svenska": "sv_SE",
45
+ "Kiswahili": "sw_KE",
46
+ "தமிழ்": "ta_IN",
47
+ "తెలుగు": "te_IN",
48
+ "ไทย": "th_TH",
49
+ "Tagalog": "tl_XX",
50
+ "Українська": "uk_UA",
51
+ "اردو": "ur_PK",
52
+ "isiXhosa": "xh_ZA",
53
+ "Galego": "gl_ES",
54
+ "Slovenščina": "sl_SI"
55
+ }
56
+
57
+ # Whisper language codes mapped to their English language names
58
+ WHISPER_LANGUAGES = {
59
+ "en": "english",
60
+ "zh": "chinese",
61
+ "de": "german",
62
+ "es": "spanish",
63
+ "ru": "russian",
64
+ "ko": "korean",
65
+ "fr": "french",
66
+ "ja": "japanese",
67
+ "pt": "portuguese",
68
+ "tr": "turkish",
69
+ "pl": "polish",
70
+ "ca": "catalan",
71
+ "nl": "dutch",
72
+ "ar": "arabic",
73
+ "sv": "swedish",
74
+ "it": "italian",
75
+ "id": "indonesian",
76
+ "hi": "hindi",
77
+ "fi": "finnish",
78
+ "vi": "vietnamese",
79
+ "he": "hebrew",
80
+ "uk": "ukrainian",
81
+ "el": "greek",
82
+ "ms": "malay",
83
+ "cs": "czech",
84
+ "ro": "romanian",
85
+ "da": "danish",
86
+ "hu": "hungarian",
87
+ "ta": "tamil",
88
+ "no": "norwegian",
89
+ "th": "thai",
90
+ "ur": "urdu",
91
+ "hr": "croatian",
92
+ "bg": "bulgarian",
93
+ "lt": "lithuanian",
94
+ "la": "latin",
95
+ "mi": "maori",
96
+ "ml": "malayalam",
97
+ "cy": "welsh",
98
+ "sk": "slovak",
99
+ "te": "telugu",
100
+ "fa": "persian",
101
+ "lv": "latvian",
102
+ "bn": "bengali",
103
+ "sr": "serbian",
104
+ "az": "azerbaijani",
105
+ "sl": "slovenian",
106
+ "kn": "kannada",
107
+ "et": "estonian",
108
+ "mk": "macedonian",
109
+ "br": "breton",
110
+ "eu": "basque",
111
+ "is": "icelandic",
112
+ "hy": "armenian",
113
+ "ne": "nepali",
114
+ "mn": "mongolian",
115
+ "bs": "bosnian",
116
+ "kk": "kazakh",
117
+ "sq": "albanian",
118
+ "sw": "swahili",
119
+ "gl": "galician",
120
+ "mr": "marathi",
121
+ "pa": "punjabi",
122
+ "si": "sinhala",
123
+ "km": "khmer",
124
+ "sn": "shona",
125
+ "yo": "yoruba",
126
+ "so": "somali",
127
+ "af": "afrikaans",
128
+ "oc": "occitan",
129
+ "ka": "georgian",
130
+ "be": "belarusian",
131
+ "tg": "tajik",
132
+ "sd": "sindhi",
133
+ "gu": "gujarati",
134
+ "am": "amharic",
135
+ "yi": "yiddish",
136
+ "lo": "lao",
137
+ "uz": "uzbek",
138
+ "fo": "faroese",
139
+ "ht": "haitian creole",
140
+ "ps": "pashto",
141
+ "tk": "turkmen",
142
+ "nn": "nynorsk",
143
+ "mt": "maltese",
144
+ "sa": "sanskrit",
145
+ "lb": "luxembourgish",
146
+ "my": "myanmar",
147
+ "bo": "tibetan",
148
+ "tl": "tagalog",
149
+ "mg": "malagasy",
150
+ "as": "assamese",
151
+ "tt": "tatar",
152
+ "haw": "hawaiian",
153
+ "ln": "lingala",
154
+ "ha": "hausa",
155
+ "ba": "bashkir",
156
+ "jw": "javanese",
157
+ "su": "sundanese",
158
+ }
159
+
160
def union_language_dict():
    """Combine the translator and Whisper language tables.

    Returns:
        A dict keyed by the language names of ``LANGUAGE_NAME_TO_CODE``. Each
        value is ``{"transcriber": <prefix>, "translator": <full code>}``,
        where ``<prefix>`` is the lower-cased part of the translator code
        before the first underscore. Languages whose prefix is not a key of
        ``WHISPER_LANGUAGES`` are left out.
    """
    # Pair every entry with its lower-cased code prefix (e.g. "es_XX" -> "es")
    candidates = (
        (name, full_code, full_code.split('_')[0].lower())
        for name, full_code in LANGUAGE_NAME_TO_CODE.items()
    )
    # Keep only the languages that Whisper also knows about
    return {
        name: {"transcriber": prefix, "translator": full_code}
        for name, full_code, prefix in candidates
        if prefix in WHISPER_LANGUAGES
    }
translate_transcriptions.py ADDED
@@ -0,0 +1,84 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
3
+ from lang_list import LANGUAGE_NAME_TO_CODE, WHISPER_LANGUAGES
4
+ import argparse
5
+ import re
6
+
7
# Lookup table: language name -> {"transcriber": Whisper code, "translator": translator code}
language_dict = {}
for language_name, language_code in LANGUAGE_NAME_TO_CODE.items():
    # The transcriber code is the lower-cased part before the underscore
    lang_code = language_code.partition('_')[0].lower()

    # Languages that Whisper does not list are skipped
    if lang_code not in WHISPER_LANGUAGES:
        continue

    language_dict[language_name] = dict(transcriber=lang_code, translator=language_code)
20
+
21
+
22
+
23
def translate(transcribed_text, source_languaje, target_languaje, translate_model, translate_tokenizer, device="cpu"):
    """Translate a piece of text from one language to another.

    Args:
        transcribed_text: Text to translate.
        source_languaje: Language name of the input; must be a key of ``language_dict``.
        target_languaje: Language name of the output; must be a key of ``language_dict``.
        translate_model: Model exposing ``generate`` (the script loads an
            ``MBartForConditionalGeneration``).
        translate_tokenizer: Tokenizer exposing ``src_lang``, ``lang_code_to_id``
            and ``batch_decode`` (the script loads an ``MBart50TokenizerFast``).
        device: Device the encoded inputs are moved to before generation.

    Returns:
        The first decoded sequence, with special tokens removed.

    Raises:
        KeyError: If either language name is not present in ``language_dict``.
    """
    # Get source and target language codes
    source_languaje_code = language_dict[source_languaje]["translator"]
    target_languaje_code = language_dict[target_languaje]["translator"]

    # Tell the tokenizer which language the input is written in. The source
    # code used to be computed and then ignored, so the input was always
    # encoded with the tokenizer's default source language.
    translate_tokenizer.src_lang = source_languaje_code

    encoded = translate_tokenizer(transcribed_text, return_tensors="pt").to(device)
    generated_tokens = translate_model.generate(
        **encoded,
        # Force the first generated token to be the target language code
        forced_bos_token_id=translate_tokenizer.lang_code_to_id[target_languaje_code],
    )
    return translate_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
36
+
37
def main(transcription_file, source_languaje, target_languaje, translate_model, translate_tokenizer, device):
    """Translate the text lines of a transcription file and save the result.

    Lines that are only digits, lines that start with a
    ``HH:MM:SS,mmm --> HH:MM:SS,mmm`` timestamp range, and empty lines are
    copied unchanged; every other line is passed through ``translate``.
    The result is written to
    ``translated_transcriptions/<input file stem>_<target_languaje>.srt``.

    Args:
        transcription_file: Path of the transcription file to read.
        source_languaje: Source language name, forwarded to ``translate``.
        target_languaje: Target language name, forwarded to ``translate`` and
            used in the output file name.
        translate_model: Translation model, forwarded to ``translate``.
        translate_tokenizer: Translation tokenizer, forwarded to ``translate``.
        device: Device name, forwarded to ``translate``.
    """
    # Local import: pathlib is not among this file's top-level imports
    from pathlib import Path

    output_folder = Path("translated_transcriptions")
    # Path.stem copes with any directory depth and with dots inside the file
    # name; unpacking split("/") / split(".") only worked for "dir/name.ext".
    transcription_file_name = Path(transcription_file).stem

    # Compiled once instead of on every line
    index_line = re.compile(r"\d+$")
    timestamp_line = re.compile(r"\d\d:\d\d:\d\d,\d\d\d --> \d\d:\d\d:\d\d,\d\d\d")

    # Read transcription (explicit UTF-8: the text is multilingual and must
    # not depend on the platform's default encoding)
    with open(transcription_file, "r", encoding="utf-8") as f:
        transcription = f.read().splitlines()

    # Translate
    translated_lines = []
    for line in transcription:
        if not line or index_line.match(line) or timestamp_line.match(line):
            # Structural line (blank, cue index or timestamps): keep as is
            translated_lines.append(line)
        else:
            translated_lines.append(
                translate(line, source_languaje, target_languaje, translate_model, translate_tokenizer, device)
            )

    # Save translation; create the output folder if it does not exist yet
    output_folder.mkdir(parents=True, exist_ok=True)
    output_file = output_folder / f"{transcription_file_name}_{target_languaje}.srt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write("".join(f"{line}\n" for line in translated_lines))
64
+
65
if __name__ == "__main__":
    # Command line: one positional transcription file plus language and device options
    cli = argparse.ArgumentParser()
    cli.add_argument("transcription_file", help="Transcribed text")
    for required_option in ("--source_languaje", "--target_languaje"):
        cli.add_argument(required_option, type=str, required=True)
    cli.add_argument("--device", type=str, default="cpu")
    options = cli.parse_args()

    # Model and tokenizer are loaded from the same checkpoint
    checkpoint = "facebook/mbart-large-50-many-to-many-mmt"
    print("Loading translation model")
    model = MBartForConditionalGeneration.from_pretrained(checkpoint).to(options.device)
    tokenizer = MBart50TokenizerFast.from_pretrained(checkpoint)
    print("Translation model loaded")

    main(
        options.transcription_file,
        options.source_languaje,
        options.target_languaje,
        model,
        tokenizer,
        options.device,
    )