dipesh1701 commited on
Commit
93d168d
1 Parent(s): 48ff56c
Files changed (1) hide show
  1. app.py +11 -15
app.py CHANGED
@@ -1,12 +1,11 @@
1
  import torch
2
  import gradio as gr
3
  import time
4
- import asyncio
5
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
  from flores200_codes import flores_codes
7
 
8
  # Load models and tokenizers once during initialization
9
- async def load_models():
10
  model_name_dict = {
11
  "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
12
  }
@@ -15,8 +14,8 @@ async def load_models():
15
 
16
  for call_name, real_name in model_name_dict.items():
17
  print("\tLoading model:", call_name)
18
- model = await asyncio.to_thread(AutoModelForSeq2SeqLM.from_pretrained, real_name)
19
- tokenizer = await asyncio.to_thread(AutoTokenizer.from_pretrained, real_name)
20
  model_dict[call_name] = {
21
  "model": model,
22
  "tokenizer": tokenizer,
@@ -28,14 +27,14 @@ async def load_models():
28
  def translate_text(source_lang, target_lang, input_text, model_dict):
29
  model_name = "nllb-distilled-600M"
30
 
31
- start_time = time.time()
32
- source_code = flores_codes[source_lang]
33
- target_code = flores_codes[target_lang]
34
-
35
  if model_name in model_dict:
36
  model = model_dict[model_name]["model"]
37
  tokenizer = model_dict[model_name]["tokenizer"]
38
 
 
 
 
 
39
  translator = pipeline(
40
  "translation",
41
  model=model,
@@ -57,11 +56,11 @@ def translate_text(source_lang, target_lang, input_text, model_dict):
57
  else:
58
  raise KeyError(f"Model '{model_name}' not found in model_dict")
59
 
60
- async def main():
61
  print("\tInitializing models")
62
 
63
  # Load models and tokenizers
64
- model_dict = await load_models()
65
 
66
  lang_codes = list(flores_codes.keys())
67
  inputs = [
@@ -72,10 +71,10 @@ async def main():
72
 
73
  outputs = gr.outputs.JSON()
74
 
75
- title = "Masterful Translator"
76
 
77
  app_description = (
78
- "This is a beta version of the Masterful Translator that utilizes pre-trained language models for translation."
79
  )
80
  examples = [["English", "Nepali", "Hello, how are you?"]]
81
 
@@ -88,6 +87,3 @@ async def main():
88
  examples=examples,
89
  examples_per_page=50,
90
  ).launch()
91
-
92
- if __name__ == "__main__":
93
- asyncio.run(main())
 
1
  import torch
2
  import gradio as gr
3
  import time
 
4
  from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
5
  from flores200_codes import flores_codes
6
 
7
  # Load models and tokenizers once during initialization
8
+ def load_models():
9
  model_name_dict = {
10
  "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
11
  }
 
14
 
15
  for call_name, real_name in model_name_dict.items():
16
  print("\tLoading model:", call_name)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
18
+ tokenizer = AutoTokenizer.from_pretrained(real_name)
19
  model_dict[call_name] = {
20
  "model": model,
21
  "tokenizer": tokenizer,
 
27
  def translate_text(source_lang, target_lang, input_text, model_dict):
28
  model_name = "nllb-distilled-600M"
29
 
 
 
 
 
30
  if model_name in model_dict:
31
  model = model_dict[model_name]["model"]
32
  tokenizer = model_dict[model_name]["tokenizer"]
33
 
34
+ start_time = time.time()
35
+ source_code = flores_codes[source_lang]
36
+ target_code = flores_codes[target_lang]
37
+
38
  translator = pipeline(
39
  "translation",
40
  model=model,
 
56
  else:
57
  raise KeyError(f"Model '{model_name}' not found in model_dict")
58
 
59
+ if __name__ == "__main__":
60
  print("\tInitializing models")
61
 
62
  # Load models and tokenizers
63
+ model_dict = load_models()
64
 
65
  lang_codes = list(flores_codes.keys())
66
  inputs = [
 
71
 
72
  outputs = gr.outputs.JSON()
73
 
74
+ title = "The Master Betters Translator"
75
 
76
  app_description = (
77
+ "This is a beta version of The Master Betters Translator that utilizes pre-trained language models for translation."
78
  )
79
  examples = [["English", "Nepali", "Hello, how are you?"]]
80
 
 
87
  examples=examples,
88
  examples_per_page=50,
89
  ).launch()