avans06 committed
Commit 90c1d05
Parent: 4fbf19d

Added a web UI page for Translation to test the translation capabilities of the model.

Files changed (3)
  1. app.py +172 -51
  2. config.json5 +6 -0
  3. docs/translateModel.md +1 -0
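
The new Translation tab reuses the wiring pattern already used by the transcribe tabs in the diff below: every Gradio component is tagged with an `elem_id`, the components are passed to the click handler as a set, and the entry function re-keys the resulting `{component: value}` dict by `elem_id` before dispatching. The following is a minimal, hypothetical sketch of that pattern in isolation (component names and the echo callback are illustrative), not code from this commit:

```python
import gradio as gr

def echo_entry(data: dict):
    # Gradio delivers {component: value} when `inputs` is a set;
    # re-key by elem_id, mirroring translation_entry_progress in app.py.
    dataDict = {component.elem_id: value for component, value in data.items()}
    return f'{dataDict["inputLangName"]}: {dataDict["inputText"]}'

with gr.Blocks() as demo:
    inputLang = gr.Dropdown(label="Input - Language", choices=["English", "Japanese"], elem_id="inputLangName")
    inputText = gr.Text(lines=5, label="Input - Text", elem_id="inputText")
    outputText = gr.Text(label="Translation Text", elem_id="outputTranslationText")
    submitBtn = gr.Button("Submit", variant="primary")
    # Passing a set (not a list) of components makes Gradio hand the callback a
    # single dict instead of positional arguments.
    submitBtn.click(fn=echo_entry, inputs={inputLang, inputText}, outputs=[outputText])

if __name__ == "__main__":
    demo.launch()
```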
app.py CHANGED
@@ -42,6 +42,7 @@ from src.translation.translationLangs import (TranslationLang,
     _TO_LANG_CODE_WHISPER, sort_lang_by_whisper_codes,
     get_lang_from_whisper_name, get_lang_from_whisper_code, get_lang_from_nllb_name, get_lang_from_m2m100_name, get_lang_from_seamlessTx_name,
     get_lang_whisper_names, get_lang_nllb_names, get_lang_m2m100_names, get_lang_seamlessTx_names)
+import re
 import shutil
 import zhconv
 import tqdm
@@ -214,26 +215,6 @@ class WhisperTranscriber:
         whisperModelName: str = decodeOptions.pop("whisperModelName")
         whisperLangName: str = decodeOptions.pop("whisperLangName")

-        translateInput: str = decodeOptions.pop("translateInput")
-        m2m100ModelName: str = decodeOptions.pop("m2m100ModelName")
-        m2m100LangName: str = decodeOptions.pop("m2m100LangName")
-        nllbModelName: str = decodeOptions.pop("nllbModelName")
-        nllbLangName: str = decodeOptions.pop("nllbLangName")
-        mt5ModelName: str = decodeOptions.pop("mt5ModelName")
-        mt5LangName: str = decodeOptions.pop("mt5LangName")
-        ALMAModelName: str = decodeOptions.pop("ALMAModelName")
-        ALMALangName: str = decodeOptions.pop("ALMALangName")
-        madlad400ModelName: str = decodeOptions.pop("madlad400ModelName")
-        madlad400LangName: str = decodeOptions.pop("madlad400LangName")
-        seamlessModelName: str = decodeOptions.pop("seamlessModelName")
-        seamlessLangName: str = decodeOptions.pop("seamlessLangName")
-
-        translationBatchSize: int = decodeOptions.pop("translationBatchSize")
-        translationNoRepeatNgramSize: int = decodeOptions.pop("translationNoRepeatNgramSize")
-        translationNumBeams: int = decodeOptions.pop("translationNumBeams")
-        translationTorchDtypeFloat16: bool = decodeOptions.pop("translationTorchDtypeFloat16")
-        translationUsingBitsandbytes: str = decodeOptions.pop("translationUsingBitsandbytes")
-
         sourceInput: str = decodeOptions.pop("sourceInput")
         urlData: str = decodeOptions.pop("urlData")
         multipleFiles: List = decodeOptions.pop("multipleFiles")
@@ -346,36 +327,7 @@ class WhisperTranscriber:
                                       cache=self.model_cache, models=self.app_config.models["whisper"])

         progress(0, desc="init translate model")
-        translationLang = None
-        translationModel = None
-        if translateInput == "m2m100" and m2m100LangName is not None and len(m2m100LangName) > 0:
-            selectedModelName = m2m100ModelName if m2m100ModelName is not None and len(m2m100ModelName) > 0 else "m2m100_418M/facebook"
-            selectedModel = next((modelConfig for modelConfig in self.app_config.models["m2m100"] if modelConfig.name == selectedModelName), None)
-            translationLang = get_lang_from_m2m100_name(m2m100LangName)
-        elif translateInput == "nllb" and nllbLangName is not None and len(nllbLangName) > 0:
-            selectedModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
-            selectedModel = next((modelConfig for modelConfig in self.app_config.models["nllb"] if modelConfig.name == selectedModelName), None)
-            translationLang = get_lang_from_nllb_name(nllbLangName)
-        elif translateInput == "mt5" and mt5LangName is not None and len(mt5LangName) > 0:
-            selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
-            selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
-            translationLang = get_lang_from_m2m100_name(mt5LangName)
-        elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
-            selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-7B-ct2:int8_float16/avan"
-            selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
-            translationLang = get_lang_from_m2m100_name(ALMALangName)
-        elif translateInput == "madlad400" and madlad400LangName is not None and len(madlad400LangName) > 0:
-            selectedModelName = madlad400ModelName if madlad400ModelName is not None and len(madlad400ModelName) > 0 else "madlad400-3b-mt-ct2-int8_float16/SoybeanMilk"
-            selectedModel = next((modelConfig for modelConfig in self.app_config.models["madlad400"] if modelConfig.name == selectedModelName), None)
-            translationLang = get_lang_from_m2m100_name(madlad400LangName)
-        elif translateInput == "seamless" and seamlessLangName is not None and len(seamlessLangName) > 0:
-            selectedModelName = seamlessModelName if seamlessModelName is not None and len(seamlessModelName) > 0 else "seamless-m4t-v2-large/facebook"
-            selectedModel = next((modelConfig for modelConfig in self.app_config.models["seamless"] if modelConfig.name == selectedModelName), None)
-            translationLang = get_lang_from_seamlessTx_name(seamlessLangName)
-
-
-        if translationLang is not None:
-            translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=whisperLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams, torchDtypeFloat16=translationTorchDtypeFloat16, usingBitsandbytes=translationUsingBitsandbytes)
+        translationLang, translationModel = self.initTranslationModel(whisperLangName, whisperLang, decodeOptions)

         progress(0, desc="init transcribe")
         # Result
@@ -871,6 +823,123 @@ class WhisperTranscriber:
         self.diarization.cleanup()
         self.diarization = None

+    # Entry function for the simple or full tab, Queue mode disabled: progress bars will not be shown
+    def translation_entry(self, data: dict): return self.translation_entry_progress(data)
+
+    # Entry function for the simple or full tab with progress, Progress tracking requires queuing to be enabled
+    def translation_entry_progress(self, data: dict, progress=gr.Progress()):
+        dataDict = {}
+        for key, value in data.items():
+            dataDict.update({key.elem_id: value})
+
+        return self.translation_webui(dataDict, progress=progress)
+
+    def translation_webui(self, dataDict: dict, progress: gr.Progress = None):
+        try:
+            inputText: str = dataDict.pop("inputText")
+            inputLangName: str = dataDict.pop("inputLangName")
+            inputLang: TranslationLang = get_lang_from_whisper_name(inputLangName)
+
+            progress(0, desc="init translate model")
+            translationLang, translationModel = self.initTranslationModel(inputLangName, inputLang, dataDict)
+
+            result = []
+            if translationModel and translationModel.translationLang:
+                try:
+                    inputTexts = inputText.split("\n")
+
+                    progress(0, desc="Translation starting...")
+
+                    perf_start_time = time.perf_counter()
+                    translationModel.load_model()
+                    for idx, text in enumerate(tqdm.tqdm(inputTexts)):
+                        if not text or re.match("""^[\u2000-\u206F\u2E00-\u2E7F\\'!"#$%&()*+,\-.\/:;<=>?@\[\]^_`{|}~\d ]+$""", text.strip()):
+                            result.append(text)
+                        else:
+                            result.append(translationModel.translation(text))
+                        progress((idx+1)/len(inputTexts), desc=f"Process inputText: {idx+1}/{len(inputTexts)}")
+
+                    translationModel.release_vram()
+                    perf_end_time = time.perf_counter()
+                    # Call the finished callback
+                    progress(1, desc=f"Process inputText: {idx+1}/{len(inputTexts)}")
+
+                    print("\n\nprocess inputText took {} seconds.\n\n".format(perf_end_time - perf_start_time))
+                except Exception as e:
+                    print(traceback.format_exc())
+                    print("Error process inputText: " + str(e))
+
+            resultStr = "\n".join(result)
+
+            translationZho: bool = translationModel and translationModel.translationLang and translationModel.translationLang.nllb and translationModel.translationLang.nllb.code in ["zho_Hant", "zho_Hans", "yue_Hant"]
+            if translationZho:
+                if translationModel.translationLang.nllb.code == "zho_Hant":
+                    locale = "zh-tw"
+                elif translationModel.translationLang.nllb.code == "zho_Hans":
+                    locale = "zh-cn"
+                elif translationModel.translationLang.nllb.code == "yue_Hant":
+                    locale = "zh-hk"
+                resultStr = zhconv.convert(resultStr, locale)
+
+            return resultStr
+        except Exception as e:
+            print(traceback.format_exc())
+            return "Error occurred during transcribe: " + str(e) + "\n\n" + traceback.format_exc()
+
+    def initTranslationModel(self, inputLangName: str, inputLang: TranslationLang, dataDict: dict):
+        translateInput: str = dataDict.pop("translateInput")
+        m2m100ModelName: str = dataDict.pop("m2m100ModelName")
+        m2m100LangName: str = dataDict.pop("m2m100LangName")
+        nllbModelName: str = dataDict.pop("nllbModelName")
+        nllbLangName: str = dataDict.pop("nllbLangName")
+        mt5ModelName: str = dataDict.pop("mt5ModelName")
+        mt5LangName: str = dataDict.pop("mt5LangName")
+        ALMAModelName: str = dataDict.pop("ALMAModelName")
+        ALMALangName: str = dataDict.pop("ALMALangName")
+        madlad400ModelName: str = dataDict.pop("madlad400ModelName")
+        madlad400LangName: str = dataDict.pop("madlad400LangName")
+        seamlessModelName: str = dataDict.pop("seamlessModelName")
+        seamlessLangName: str = dataDict.pop("seamlessLangName")
+
+        translationBatchSize: int = dataDict.pop("translationBatchSize")
+        translationNoRepeatNgramSize: int = dataDict.pop("translationNoRepeatNgramSize")
+        translationNumBeams: int = dataDict.pop("translationNumBeams")
+        translationTorchDtypeFloat16: bool = dataDict.pop("translationTorchDtypeFloat16")
+        translationUsingBitsandbytes: str = dataDict.pop("translationUsingBitsandbytes")
+
+        translationLang = None
+        translationModel = None
+        if translateInput == "m2m100" and m2m100LangName is not None and len(m2m100LangName) > 0:
+            selectedModelName = m2m100ModelName if m2m100ModelName is not None and len(m2m100ModelName) > 0 else "m2m100_418M/facebook"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["m2m100"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(m2m100LangName)
+        elif translateInput == "nllb" and nllbLangName is not None and len(nllbLangName) > 0:
+            selectedModelName = nllbModelName if nllbModelName is not None and len(nllbModelName) > 0 else "nllb-200-distilled-600M/facebook"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["nllb"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_nllb_name(nllbLangName)
+        elif translateInput == "mt5" and mt5LangName is not None and len(mt5LangName) > 0:
+            selectedModelName = mt5ModelName if mt5ModelName is not None and len(mt5ModelName) > 0 else "mt5-zh-ja-en-trimmed/K024"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["mt5"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(mt5LangName)
+        elif translateInput == "ALMA" and ALMALangName is not None and len(ALMALangName) > 0:
+            selectedModelName = ALMAModelName if ALMAModelName is not None and len(ALMAModelName) > 0 else "ALMA-7B-ct2:int8_float16/avan"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["ALMA"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(ALMALangName)
+        elif translateInput == "madlad400" and madlad400LangName is not None and len(madlad400LangName) > 0:
+            selectedModelName = madlad400ModelName if madlad400ModelName is not None and len(madlad400ModelName) > 0 else "madlad400-3b-mt-ct2-int8_float16/SoybeanMilk"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["madlad400"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_m2m100_name(madlad400LangName)
+        elif translateInput == "seamless" and seamlessLangName is not None and len(seamlessLangName) > 0:
+            selectedModelName = seamlessModelName if seamlessModelName is not None and len(seamlessModelName) > 0 else "seamless-m4t-v2-large/facebook"
+            selectedModel = next((modelConfig for modelConfig in self.app_config.models["seamless"] if modelConfig.name == selectedModelName), None)
+            translationLang = get_lang_from_seamlessTx_name(seamlessLangName)
+
+        if translationLang is not None:
+            translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=inputLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams, torchDtypeFloat16=translationTorchDtypeFloat16, usingBitsandbytes=translationUsingBitsandbytes)
+
+        return translationLang, translationModel
+
+
 def create_ui(app_config: ApplicationConfig):
     translateModelMd: str = None
     optionsMd: str = None
 
@@ -1135,11 +1204,63 @@ def create_ui(app_config: ApplicationConfig):

         return transcribe

+    def create_translation(isQueueMode: bool):
+        with gr.Blocks() as translation:
+            translateInput = gr.State(value="m2m100", elem_id = "translateInput")
+            with gr.Row():
+                with gr.Column():
+                    submitBtn = gr.Button("Submit", variant="primary")
+                with gr.Column():
+                    with gr.Tab(label="M2M100") as m2m100Tab:
+                        with gr.Row():
+                            inputDict = common_m2m100_inputs()
+                    with gr.Tab(label="NLLB") as nllbTab:
+                        with gr.Row():
+                            inputDict.update(common_nllb_inputs())
+                    with gr.Tab(label="MT5") as mt5Tab:
+                        with gr.Row():
+                            inputDict.update(common_mt5_inputs())
+                    with gr.Tab(label="ALMA") as almaTab:
+                        with gr.Row():
+                            inputDict.update(common_ALMA_inputs())
+                    with gr.Tab(label="madlad400") as madlad400Tab:
+                        with gr.Row():
+                            inputDict.update(common_madlad400_inputs())
+                    with gr.Tab(label="seamless") as seamlessTab:
+                        with gr.Row():
+                            inputDict.update(common_seamless_inputs())
+                    m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
+                    nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
+                    mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
+                    almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
+                    madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
+                    seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
+                with gr.Column():
+                    inputDict.update({
+                        gr.Dropdown(label="Input - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="inputLangName"),
+                        gr.Text(lines=5, label="Input - Text", elem_id="inputText", elem_classes="scroll-show"),
+                    })
+                with gr.Column():
+                    with gr.Accordion("Translation options", open=False):
+                        inputDict.update(common_translation_inputs())
+                with gr.Column():
+                    outputs = [gr.Text(label="Translation Text", autoscroll=False, show_copy_button=True, interactive=True, elem_id="outputTranslationText", elem_classes="scroll-show"),]
+                    if translateModelMd is not None:
+                        with gr.Accordion("docs/translateModel.md", open=False):
+                            gr.Markdown(translateModelMd)
+
+            inputDict.update({translateInput})
+            submitBtn.click(fn=ui.translation_entry_progress if isQueueMode else ui.translation_entry,
+                            inputs=inputDict, outputs=outputs)
+
+        return translation
+
     simpleTranscribe = create_transcribe(uiDescription, is_queue_mode)
     fullDescription = uiDescription + "\n\n\n\n" + "Be careful when changing some of the options in the full interface - this can cause the model to crash."
     fullTranscribe = create_transcribe(fullDescription, is_queue_mode, True)
+    uiTranslation = create_translation(is_queue_mode)

-    demo = gr.TabbedInterface([simpleTranscribe, fullTranscribe], tab_names=["Simple", "Full"], css=css)
+    demo = gr.TabbedInterface([simpleTranscribe, fullTranscribe, uiTranslation], tab_names=["Simple", "Full", "Translation"], css=css)

     # Queue up the demo
     if is_queue_mode:
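
One detail of `translation_webui` above worth calling out: before sending a line to the translation model, it checks a character-class regex and passes through any line made up only of punctuation, symbols, digits, and spaces. The snippet below is a small standalone illustration of that filter using the same pattern; the sample strings are hypothetical.

```python
import re

# Same pattern as translation_webui: lines consisting only of general/supplemental
# punctuation, ASCII punctuation, digits, and spaces are returned unchanged.
PATTERN = """^[\u2000-\u206F\u2E00-\u2E7F\\'!"#$%&()*+,\-.\/:;<=>?@\[\]^_`{|}~\d ]+$"""

for line in ["1. ???", "-----", "Hello, how are you?", ""]:
    skip = (not line) or re.match(PATTERN, line.strip())
    print(repr(line), "-> passed through" if skip else "-> sent to the translation model")
```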
config.json5 CHANGED
@@ -248,6 +248,12 @@
248
  "type": "huggingface",
249
  "tokenizer_url": "jbochi/madlad400-3b-mt"
250
  },
 
 
 
 
 
 
251
  {
252
  "name": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
253
  "url": "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16",
 
248
  "type": "huggingface",
249
  "tokenizer_url": "jbochi/madlad400-3b-mt"
250
  },
251
+ {
252
+ "name": "madlad400-7b-mt-bt-ct2-int8_float16/avan",
253
+ "url": "avans06/madlad400-7b-mt-bt-ct2-int8_float16",
254
+ "type": "huggingface",
255
+ "tokenizer_url": "jbochi/madlad400-7b-mt-bt"
256
+ },
257
  {
258
  "name": "madlad400-10b-mt-ct2-int8_float16/SoybeanMilk",
259
  "url": "SoybeanMilk/madlad400-10b-mt-ct2-int8_float16",
docs/translateModel.md CHANGED
@@ -129,6 +129,7 @@ madlad400 is a multilingual machine translation model based on the T5 architecture.
 | Name | Parameters | Size | type/quantize | Required VRAM |
 |------|------------|------|---------------|---------------|
 | [SoybeanMilk/madlad400-3b-mt-ct2-int8_float16](https://huggingface.co/SoybeanMilk/madlad400-3b-mt-ct2-int8_float16) | 3B | 2.95 GB | int8_float16 | ≈2.7 GB |
+| [avans06/madlad400-7b-mt-bt-ct2-int8_float16](https://huggingface.co/avans06/madlad400-7b-mt-bt-ct2-int8_float16) | 7.2B | 8.31 GB | int8_float16 (finetuned on backtranslated data) | ≈8.5 GB |
 | [SoybeanMilk/madlad400-10b-mt-ct2-int8_float16](https://huggingface.co/SoybeanMilk/madlad400-10b-mt-ct2-int8_float16) | 10.7B | 10.7 GB | int8_float16 | ≈10 GB |

 ## SeamlessM4T