jhj0517 commited on
Commit
a2c62c7
·
1 Parent(s): 007ba83

Refactor language validation logic

Browse files
modules/translation/nllb_inference.py CHANGED
@@ -37,6 +37,17 @@ class NLLBInference(TranslationBase):
37
  tgt_lang: str,
38
  progress: gr.Progress = gr.Progress()
39
  ):
 
 
 
 
 
 
 
 
 
 
 
40
  if model_size != self.current_model_size or self.model is None:
41
  print("\nInitializing NLLB Model..\n")
42
  progress(0, desc="Initializing NLLB Model..")
@@ -48,8 +59,7 @@ class NLLBInference(TranslationBase):
48
  self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
49
  cache_dir=os.path.join(self.model_dir, "tokenizers"),
50
  local_files_only=local_files_only)
51
- src_lang = NLLB_AVAILABLE_LANGS[src_lang]
52
- tgt_lang = NLLB_AVAILABLE_LANGS[tgt_lang]
53
  self.pipeline = pipeline("translation",
54
  model=self.model,
55
  tokenizer=self.tokenizer,
 
37
  tgt_lang: str,
38
  progress: gr.Progress = gr.Progress()
39
  ):
40
+ def validate_language(lang: str) -> str:
41
+ if lang in NLLB_AVAILABLE_LANGS:
42
+ return NLLB_AVAILABLE_LANGS[lang]
43
+ elif lang not in NLLB_AVAILABLE_LANGS.values():
44
+ raise ValueError(
45
+ f"Language '{lang}' is not supported. Use one of: {list(NLLB_AVAILABLE_LANGS.keys())}")
46
+ return lang
47
+
48
+ src_lang = validate_language(src_lang)
49
+ tgt_lang = validate_language(tgt_lang)
50
+
51
  if model_size != self.current_model_size or self.model is None:
52
  print("\nInitializing NLLB Model..\n")
53
  progress(0, desc="Initializing NLLB Model..")
 
59
  self.tokenizer = AutoTokenizer.from_pretrained(pretrained_model_name_or_path=model_size,
60
  cache_dir=os.path.join(self.model_dir, "tokenizers"),
61
  local_files_only=local_files_only)
62
+
 
63
  self.pipeline = pipeline("translation",
64
  model=self.model,
65
  tokenizer=self.tokenizer,