avans06 commited on
Commit
61d82fd
·
1 Parent(s): ca3bee7

Add Meta-Llama-3-8B-Instruct ctranslate2 as the translation model to use.

Browse files
Files changed (4) hide show
  1. app.py +28 -9
  2. config.json5 +8 -0
  3. src/config.py +2 -2
  4. src/translation/translationModel.py +14 -2
app.py CHANGED
@@ -921,6 +921,8 @@ class WhisperTranscriber:
921
  madlad400LangName: str = dataDict.pop("madlad400LangName")
922
  seamlessModelName: str = dataDict.pop("seamlessModelName")
923
  seamlessLangName: str = dataDict.pop("seamlessLangName")
 
 
924
 
925
  translationBatchSize: int = dataDict.pop("translationBatchSize")
926
  translationNoRepeatNgramSize: int = dataDict.pop("translationNoRepeatNgramSize")
@@ -954,6 +956,10 @@ class WhisperTranscriber:
954
  selectedModelName = seamlessModelName if seamlessModelName is not None and len(seamlessModelName) > 0 else "seamless-m4t-v2-large/facebook"
955
  selectedModel = next((modelConfig for modelConfig in self.app_config.models["seamless"] if modelConfig.name == selectedModelName), None)
956
  translationLang = get_lang_from_seamlessT_Tx_name(seamlessLangName)
 
 
 
 
957
 
958
  if translationLang is not None:
959
  translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=inputLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams, torchDtypeFloat16=translationTorchDtypeFloat16, usingBitsandbytes=translationUsingBitsandbytes)
@@ -1023,6 +1029,7 @@ def create_ui(app_config: ApplicationConfig):
1023
  ALMA_models = app_config.get_model_names("ALMA")
1024
  madlad400_models = app_config.get_model_names("madlad400")
1025
  seamless_models = app_config.get_model_names("seamless")
 
1026
  if not torch.cuda.is_available(): # Loading only quantized or models with medium-low parameters in an environment without GPU support.
1027
  nllb_models = list(filter(lambda nllb: any(name in nllb for name in ["-600M", "-1.3B", "-3.3B-ct2"]), nllb_models))
1028
  m2m100_models = list(filter(lambda m2m100: "12B" not in m2m100, m2m100_models))
@@ -1057,20 +1064,24 @@ def create_ui(app_config: ApplicationConfig):
1057
  gr.Dropdown(label="seamless - Model (for translate)", choices=seamless_models, elem_id="seamlessModelName"),
1058
  gr.Dropdown(label="seamless - Language", choices=sorted(get_lang_seamlessT_Tx_names()), elem_id="seamlessLangName"),
1059
  }
 
 
 
 
1060
 
1061
  common_translation_inputs = lambda : {
1062
  gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
1063
- gr.Number(label="Translation - No Repeat Ngram Size", precision=0, value=app_config.translation_no_repeat_ngram_size, elem_id="translationNoRepeatNgramSize"),
1064
- gr.Number(label="Translation - Num Beams", precision=0, value=app_config.translation_num_beams, elem_id="translationNumBeams"),
1065
  gr.Checkbox(label="Translation - Torch Dtype float16", visible=torch.cuda.is_available(), value=app_config.translation_torch_dtype_float16, info="Load the float32 translation model with float16 when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationTorchDtypeFloat16"),
1066
  gr.Radio(label="Translation - Using Bitsandbytes", visible=torch.cuda.is_available(), choices=[None, "int8", "int4"], value=app_config.translation_using_bitsandbytes, info="Load the float32 translation model into mixed-8bit or 4bit precision quantized model when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationUsingBitsandbytes"),
1067
  }
1068
 
1069
  common_vad_inputs = lambda : {
1070
  gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
1071
- gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow"),
1072
- gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize"),
1073
- gr.Number(label="VAD - Process Timeout (s)", precision=0, value=app_config.vad_process_timeout, elem_id="vadPocessTimeout"),
1074
  }
1075
 
1076
  common_word_timestamps_inputs = lambda : {
@@ -1148,12 +1159,16 @@ def create_ui(app_config: ApplicationConfig):
1148
  with gr.Tab(label="seamless") as seamlessTab:
1149
  with gr.Row():
1150
  inputDict.update(common_seamless_inputs())
 
 
 
1151
  m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
1152
  nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
1153
  mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
1154
  almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
1155
  madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
1156
  seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
 
1157
  with gr.Column():
1158
  with gr.Tab(label="URL") as UrlTab:
1159
  inputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
@@ -1164,14 +1179,14 @@ def create_ui(app_config: ApplicationConfig):
1164
  UrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [sourceInput] )
1165
  UploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [sourceInput] )
1166
  MicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [sourceInput] )
1167
- inputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task")})
1168
  with gr.Accordion("VAD options", open=False):
1169
  inputDict.update(common_vad_inputs())
1170
  if isFull:
1171
  inputDict.update({
1172
- gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding, elem_id = "vadPadding"),
1173
- gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window, elem_id = "vadPromptWindow"),
1174
- gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode", value=app_config.vad_initial_prompt_mode, elem_id = "vadInitialPromptMode")})
1175
  with gr.Accordion("Word Timestamps options", open=False):
1176
  inputDict.update(common_word_timestamps_inputs())
1177
  if isFull:
@@ -1250,12 +1265,16 @@ def create_ui(app_config: ApplicationConfig):
1250
  with gr.Tab(label="seamless") as seamlessTab:
1251
  with gr.Row():
1252
  inputDict.update(common_seamless_inputs())
 
 
 
1253
  m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
1254
  nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
1255
  mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
1256
  almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
1257
  madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
1258
  seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
 
1259
  with gr.Column():
1260
  inputDict.update({
1261
  gr.Dropdown(label="Input - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="inputLangName"),
 
921
  madlad400LangName: str = dataDict.pop("madlad400LangName")
922
  seamlessModelName: str = dataDict.pop("seamlessModelName")
923
  seamlessLangName: str = dataDict.pop("seamlessLangName")
924
+ LlamaModelName: str = dataDict.pop("LlamaModelName")
925
+ LlamaLangName: str = dataDict.pop("LlamaLangName")
926
 
927
  translationBatchSize: int = dataDict.pop("translationBatchSize")
928
  translationNoRepeatNgramSize: int = dataDict.pop("translationNoRepeatNgramSize")
 
956
  selectedModelName = seamlessModelName if seamlessModelName is not None and len(seamlessModelName) > 0 else "seamless-m4t-v2-large/facebook"
957
  selectedModel = next((modelConfig for modelConfig in self.app_config.models["seamless"] if modelConfig.name == selectedModelName), None)
958
  translationLang = get_lang_from_seamlessT_Tx_name(seamlessLangName)
959
+ elif translateInput == "Llama" and LlamaLangName is not None and len(LlamaLangName) > 0:
960
+ selectedModelName = LlamaModelName if LlamaModelName is not None and len(LlamaModelName) > 0 else "Meta-Llama-3-8B-Instruct-ct2-int8_float16/avan"
961
+ selectedModel = next((modelConfig for modelConfig in self.app_config.models["Llama"] if modelConfig.name == selectedModelName), None)
962
+ translationLang = get_lang_from_m2m100_name(LlamaLangName)
963
 
964
  if translationLang is not None:
965
  translationModel = TranslationModel(modelConfig=selectedModel, whisperLang=inputLang, translationLang=translationLang, batchSize=translationBatchSize, noRepeatNgramSize=translationNoRepeatNgramSize, numBeams=translationNumBeams, torchDtypeFloat16=translationTorchDtypeFloat16, usingBitsandbytes=translationUsingBitsandbytes)
 
1029
  ALMA_models = app_config.get_model_names("ALMA")
1030
  madlad400_models = app_config.get_model_names("madlad400")
1031
  seamless_models = app_config.get_model_names("seamless")
1032
+ Llama_models = app_config.get_model_names("Llama")
1033
  if not torch.cuda.is_available(): # Loading only quantized or models with medium-low parameters in an environment without GPU support.
1034
  nllb_models = list(filter(lambda nllb: any(name in nllb for name in ["-600M", "-1.3B", "-3.3B-ct2"]), nllb_models))
1035
  m2m100_models = list(filter(lambda m2m100: "12B" not in m2m100, m2m100_models))
 
1064
  gr.Dropdown(label="seamless - Model (for translate)", choices=seamless_models, elem_id="seamlessModelName"),
1065
  gr.Dropdown(label="seamless - Language", choices=sorted(get_lang_seamlessT_Tx_names()), elem_id="seamlessLangName"),
1066
  }
1067
+ common_Llama_inputs = lambda : {
1068
+ gr.Dropdown(label="Llama - Model (for translate)", choices=Llama_models, elem_id="LlamaModelName"),
1069
+ gr.Dropdown(label="Llama - Language", choices=sorted(get_lang_m2m100_names()), elem_id="LlamaLangName"),
1070
+ }
1071
 
1072
  common_translation_inputs = lambda : {
1073
  gr.Number(label="Translation - Batch Size", precision=0, value=app_config.translation_batch_size, elem_id="translationBatchSize"),
1074
+ gr.Number(label="Translation - No Repeat Ngram Size", precision=0, value=app_config.translation_no_repeat_ngram_size, elem_id="translationNoRepeatNgramSize", info="Prevent repetitions of ngrams with this size (set 0 to disable)."),
1075
+ gr.Number(label="Translation - Num Beams", precision=0, value=app_config.translation_num_beams, elem_id="translationNumBeams", info="Beam size (1 for greedy search)."),
1076
  gr.Checkbox(label="Translation - Torch Dtype float16", visible=torch.cuda.is_available(), value=app_config.translation_torch_dtype_float16, info="Load the float32 translation model with float16 when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationTorchDtypeFloat16"),
1077
  gr.Radio(label="Translation - Using Bitsandbytes", visible=torch.cuda.is_available(), choices=[None, "int8", "int4"], value=app_config.translation_using_bitsandbytes, info="Load the float32 translation model into mixed-8bit or 4bit precision quantized model when the system supports GPU (reducing VRAM usage, not applicable to models that have already been quantized, such as Ctranslate2, GPTQ, GGUF)", elem_id="translationUsingBitsandbytes"),
1078
  }
1079
 
1080
  common_vad_inputs = lambda : {
1081
  gr.Dropdown(choices=["none", "silero-vad", "silero-vad-skip-gaps", "silero-vad-expand-into-gaps", "periodic-vad"], value=app_config.default_vad, label="VAD", elem_id="vad"),
1082
+ gr.Number(label="VAD - Merge Window (s)", precision=0, value=app_config.vad_merge_window, elem_id="vadMergeWindow", info="If set, any adjacent speech sections that are at most this number of seconds apart will be automatically merged."),
1083
+ gr.Number(label="VAD - Max Merge Size (s)", precision=0, value=app_config.vad_max_merge_size, elem_id="vadMaxMergeSize", info="Disables merging of adjacent speech sections if they are this number of seconds long."),
1084
+ gr.Number(label="VAD - Process Timeout (s)", precision=0, value=app_config.vad_process_timeout, elem_id="vadPocessTimeout", info="This configures the number of seconds until a process is killed due to inactivity, freeing RAM and video memory. The default value is 30 minutes."),
1085
  }
1086
 
1087
  common_word_timestamps_inputs = lambda : {
 
1159
  with gr.Tab(label="seamless") as seamlessTab:
1160
  with gr.Row():
1161
  inputDict.update(common_seamless_inputs())
1162
+ with gr.Tab(label="Llama") as llamaTab:
1163
+ with gr.Row():
1164
+ inputDict.update(common_Llama_inputs())
1165
  m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
1166
  nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
1167
  mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
1168
  almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
1169
  madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
1170
  seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
1171
+ llamaTab.select(fn=lambda: "Llama", inputs = [], outputs= [translateInput] )
1172
  with gr.Column():
1173
  with gr.Tab(label="URL") as UrlTab:
1174
  inputDict.update({gr.Text(label="URL (YouTube, etc.)", elem_id = "urlData")})
 
1179
  UrlTab.select(fn=lambda: "urlData", inputs = [], outputs= [sourceInput] )
1180
  UploadTab.select(fn=lambda: "multipleFiles", inputs = [], outputs= [sourceInput] )
1181
  MicTab.select(fn=lambda: "microphoneData", inputs = [], outputs= [sourceInput] )
1182
+ inputDict.update({gr.Dropdown(choices=["transcribe", "translate"], label="Task", value=app_config.task, elem_id = "task", info="Select the task - either \"transcribe\" to transcribe the audio to text, or \"translate\" to translate it to English.")})
1183
  with gr.Accordion("VAD options", open=False):
1184
  inputDict.update(common_vad_inputs())
1185
  if isFull:
1186
  inputDict.update({
1187
+ gr.Number(label="VAD - Padding (s)", precision=None, value=app_config.vad_padding, elem_id = "vadPadding", info="The number of seconds (floating point) to add to the beginning and end of each speech section. Setting this to a number larger than zero ensures that Whisper is more likely to correctly transcribe a sentence in the beginning of a speech section. However, this also increases the probability of Whisper assigning the wrong timestamp to each transcribed line. The default value is 1 second."),
1188
+ gr.Number(label="VAD - Prompt Window (s)", precision=None, value=app_config.vad_prompt_window, elem_id = "vadPromptWindow", info="The text of a detected line will be included as a prompt to the next speech section, if the speech section starts at most this number of seconds after the line has finished. For instance, if a line ends at 10:00, and the next speech section starts at 10:04, the line's text will be included if the prompt window is 4 seconds or more (10:04 - 10:00 = 4 seconds)."),
1189
+ gr.Dropdown(choices=VAD_INITIAL_PROMPT_MODE_VALUES, label="VAD - Initial Prompt Mode", value=app_config.vad_initial_prompt_mode, elem_id = "vadInitialPromptMode", info="prepend_all_segments: prepend the initial prompt to each VAD segment, prepend_first_segment: just the first segment")})
1190
  with gr.Accordion("Word Timestamps options", open=False):
1191
  inputDict.update(common_word_timestamps_inputs())
1192
  if isFull:
 
1265
  with gr.Tab(label="seamless") as seamlessTab:
1266
  with gr.Row():
1267
  inputDict.update(common_seamless_inputs())
1268
+ with gr.Tab(label="Llama") as llamaTab:
1269
+ with gr.Row():
1270
+ inputDict.update(common_Llama_inputs())
1271
  m2m100Tab.select(fn=lambda: "m2m100", inputs = [], outputs= [translateInput] )
1272
  nllbTab.select(fn=lambda: "nllb", inputs = [], outputs= [translateInput] )
1273
  mt5Tab.select(fn=lambda: "mt5", inputs = [], outputs= [translateInput] )
1274
  almaTab.select(fn=lambda: "ALMA", inputs = [], outputs= [translateInput] )
1275
  madlad400Tab.select(fn=lambda: "madlad400", inputs = [], outputs= [translateInput] )
1276
  seamlessTab.select(fn=lambda: "seamless", inputs = [], outputs= [translateInput] )
1277
+ llamaTab.select(fn=lambda: "Llama", inputs = [], outputs= [translateInput] )
1278
  with gr.Column():
1279
  inputDict.update({
1280
  gr.Dropdown(label="Input - Language", choices=sorted(get_lang_whisper_names()), value=app_config.language, elem_id="inputLangName"),
config.json5 CHANGED
@@ -292,6 +292,14 @@
292
  "url": "facebook/seamless-m4t-v2-large",
293
  "type": "huggingface"
294
  }
 
 
 
 
 
 
 
 
295
  ]
296
  },
297
  // Configuration options that will be used if they are not specified in the command line arguments.
 
292
  "url": "facebook/seamless-m4t-v2-large",
293
  "type": "huggingface"
294
  }
295
+ ],
296
+ "Llama": [
297
+ {
298
+ "name": "Meta-Llama-3-8B-Instruct-ct2-int8_float16/avan",
299
+ "url": "avans06/Meta-Llama-3-8B-Instruct-ct2-int8_float16",
300
+ "type": "huggingface",
301
+ "tokenizer_url": "avans06/Meta-Llama-3-8B-Instruct-ct2-int8_float16"
302
+ }
303
  ]
304
  },
305
  // Configuration options that will be used if they are not specified in the command line arguments.
src/config.py CHANGED
@@ -50,7 +50,7 @@ class VadInitialPromptMode(Enum):
50
  return None
51
 
52
  class ApplicationConfig:
53
- def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless"], List[ModelConfig]],
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
@@ -185,7 +185,7 @@ class ApplicationConfig:
185
  # Load using json5
186
  data = json5.load(f)
187
  data_models = data.pop("models", [])
188
- models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless"], List[ModelConfig]] = {
189
  key: [ModelConfig(**item) for item in value]
190
  for key, value in data_models.items()
191
  }
 
50
  return None
51
 
52
  class ApplicationConfig:
53
+ def __init__(self, models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless", "Llama"], List[ModelConfig]],
54
  input_audio_max_duration: int = 600, share: bool = False, server_name: str = None, server_port: int = 7860,
55
  queue_concurrency_count: int = 1, delete_uploaded_files: bool = True,
56
  whisper_implementation: str = "whisper", default_model_name: str = "medium",
 
185
  # Load using json5
186
  data = json5.load(f)
187
  data_models = data.pop("models", [])
188
+ models: Dict[Literal["whisper", "m2m100", "nllb", "mt5", "ALMA", "madlad400", "seamless", "Llama"], List[ModelConfig]] = {
189
  key: [ModelConfig(**item) for item in value]
190
  for key, value in data_models.items()
191
  }
src/translation/translationModel.py CHANGED
@@ -27,7 +27,7 @@ class TranslationModel:
27
  localFilesOnly: bool = False,
28
  loadModel: bool = False,
29
  ):
30
- """Initializes the M2M100 / Nllb-200 / mt5 / ALMA / madlad400 / seamless-m4t translation model.
31
 
32
  Args:
33
  modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
@@ -230,6 +230,9 @@ class TranslationModel:
230
  if "ALMA" in self.modelPath:
231
  self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
232
  self.transModel = ctranslate2.Generator(**kwargsModel)
 
 
 
233
  else:
234
  if "nllb" in self.modelPath:
235
  kwargsTokenizer.update({"src_lang": self.whisperLang.nllb.code})
@@ -243,6 +246,8 @@ class TranslationModel:
243
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
244
  if "m2m100" in self.modelPath:
245
  self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
 
 
246
  elif "mt5" in self.modelPath:
247
  self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
248
  kwargsTokenizer.update({"pretrained_model_name_or_path": self.modelPath, "legacy": False})
@@ -382,6 +387,12 @@ class TranslationModel:
382
  output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
383
  target = output[0]
384
  result = self.transTokenizer.decode(target.sequences_ids[0])
 
 
 
 
 
 
385
  elif "madlad400" in self.modelPath:
386
  source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
387
  output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
@@ -424,7 +435,8 @@ _MODELS = ["nllb-200",
424
  "mt5",
425
  "ALMA",
426
  "madlad400",
427
- "seamless"]
 
428
 
429
  def check_model_name(name):
430
  return any(allowed_name in name for allowed_name in _MODELS)
 
27
  localFilesOnly: bool = False,
28
  loadModel: bool = False,
29
  ):
30
+ """Initializes the M2M100 / Nllb-200 / mt5 / ALMA / madlad400 / seamless-m4t / Llama translation model.
31
 
32
  Args:
33
  modelConfig: Config of the model to use (distilled-600M, distilled-1.3B,
 
230
  if "ALMA" in self.modelPath:
231
  self.ALMAPrefix = "Translate this from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ":\n" + self.whisperLang.whisper.names[0] + ": "
232
  self.transModel = ctranslate2.Generator(**kwargsModel)
233
+ elif "Llama" in self.modelPath:
234
+ self.roleSystem = {"role": "system", "content":"You are an excellent and professional translation master who understands languages from all around the world. Please directly translate the following sentence from " + self.whisperLang.whisper.names[0] + " to " + self.translationLang.whisper.names[0] + ", please simply provide the translation below without further explanation and without using any emojis."}
235
+ self.transModel = ctranslate2.Generator(**kwargsModel)
236
  else:
237
  if "nllb" in self.modelPath:
238
  kwargsTokenizer.update({"src_lang": self.whisperLang.nllb.code})
 
246
  self.transTokenizer = transformers.AutoTokenizer.from_pretrained(**kwargsTokenizer)
247
  if "m2m100" in self.modelPath:
248
  self.targetPrefix = [self.transTokenizer.lang_code_to_token[self.translationLang.m2m100.code]]
249
+ elif "Llama" in self.modelPath:
250
+ self.terminators = [self.transTokenizer.eos_token_id, self.transTokenizer.convert_tokens_to_ids("<|eot_id|>")]
251
  elif "mt5" in self.modelPath:
252
  self.mt5Prefix = self.whisperLang.whisper.code + "2" + self.translationLang.whisper.code + ": "
253
  kwargsTokenizer.update({"pretrained_model_name_or_path": self.modelPath, "legacy": False})
 
387
  output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, repetition_penalty=1.1, include_prompt_in_result=False) #, sampling_topk=40
388
  target = output[0]
389
  result = self.transTokenizer.decode(target.sequences_ids[0])
390
+ elif "Llama" in self.modelPath:
391
+ input_ids = self.transTokenizer.apply_chat_template([self.roleSystem, {"role": "user", "content": "'" + text + "', \n" + self.translationLang.whisper.names[0] + ":"}], tokenize=False, add_generation_prompt=True)
392
+ source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(input_ids))
393
+ output = self.transModel.generate_batch([source], max_length=max_length, max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams, sampling_temperature=0.7, sampling_topp=0.9, include_prompt_in_result=False, end_token=self.terminators)
394
+ target = output[0]
395
+ result = self.transTokenizer.decode(target.sequences_ids[0])
396
  elif "madlad400" in self.modelPath:
397
  source = self.transTokenizer.convert_ids_to_tokens(self.transTokenizer.encode(self.madlad400Prefix + text))
398
  output = self.transModel.translate_batch([source], max_batch_size=self.batchSize, no_repeat_ngram_size=self.noRepeatNgramSize, beam_size=self.numBeams)
 
435
  "mt5",
436
  "ALMA",
437
  "madlad400",
438
+ "seamless"
439
+ "Llama"]
440
 
441
  def check_model_name(name):
442
  return any(allowed_name in name for allowed_name in _MODELS)