dipesh1701 commited on
Commit
463444e
1 Parent(s): 997a555
Files changed (1) hide show
  1. app.py +96 -0
app.py ADDED
@@ -0,0 +1,96 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from flores200_codes import flores_codes
7
+
8
+
9
def load_models():
    """Load the NLLB checkpoint(s) and matching tokenizers.

    Returns:
        dict mapping "<alias>_model" -> AutoModelForSeq2SeqLM and
        "<alias>_tokenizer" -> AutoTokenizer for every configured alias.
    """
    # Alias -> Hugging Face hub id. Larger variants are kept here but
    # disabled for the CPU demo.
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
        #'nllb-1.3B': 'facebook/nllb-200-1.3B',
        #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
        #'nllb-3.3B': 'facebook/nllb-200-3.3B',
    }

    loaded = {}
    for alias, hub_id in checkpoints.items():
        print("\tLoading model: %s" % alias)
        loaded[alias + "_model"] = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        loaded[alias + "_tokenizer"] = AutoTokenizer.from_pretrained(hub_id)

    return loaded
28
+
29
+
30
def translation(source, target, text):
    """Translate *text* between two languages with the loaded NLLB model.

    Args:
        source: Human-readable source language name (a key of flores_codes).
        target: Human-readable target language name (a key of flores_codes).
        text: The text to translate.

    Returns:
        dict with "inference_time" (seconds), the resolved FLORES-200
        "source" and "target" codes, and the translated text in "result".

    Raises:
        KeyError: if a language name is not present in flores_codes, or the
            model was not loaded into the module-level model_dict.
    """
    # BUG FIX: the original guarded this assignment with
    # `if len(model_dict) == 2:`, leaving model_name unbound (NameError on
    # the lookups below) whenever model_dict held any other number of
    # entries. Only one model is ever loaded, so select it unconditionally.
    model_name = "nllb-distilled-600M"

    start_time = time.time()
    # Map human-readable names to the FLORES-200 codes the tokenizer expects.
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]

    # NOTE(review): a fresh pipeline is built on every request because the
    # src/tgt language pair changes per call; consider caching per-pair
    # pipelines if request volume grows.
    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )
    output = translator(text, max_length=400)

    end_time = time.time()

    output = output[0]["translation_text"]
    result = {
        "inference_time": end_time - start_time,
        "source": source,
        "target": target,
        "result": output,
    }
    return result
60
+
61
+
62
if __name__ == "__main__":
    print("\tinit models")

    # `global` at module top level is a no-op statement; model_dict is
    # already a module-level name that translation() reads as a global.
    global model_dict

    model_dict = load_models()

    # define gradio demo
    # Dropdown choices are the human-readable language names from flores_codes.
    lang_codes = list(flores_codes.keys())
    # inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
    # NOTE(review): gr.inputs.* / gr.outputs.* and the `default=` kwarg are
    # the legacy pre-3.x gradio API (removed in gradio 3+); this file
    # presumably pins an older gradio — confirm before upgrading.
    inputs = [
        gr.inputs.Dropdown(lang_codes, default="English", label="Source"),
        gr.inputs.Dropdown(lang_codes, default="Korean", label="Target"),
        gr.inputs.Textbox(lines=5, label="Input text"),
    ]

    # translation() returns a dict, rendered as JSON in the UI.
    outputs = gr.outputs.JSON()

    title = "NLLB distilled 600M demo"

    demo_status = "Demo is running on CPU"
    description = (
        f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    )
    examples = [["English", "Korean", "Hi. nice to meet you"]]

    # Wire translation(source, target, text) to the three inputs above and
    # start the web server (blocks until shutdown).
    gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()