sethanimesh committed on
Commit d28a390
1 Parent(s): 9c87c10
Files changed (1)
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
import os
import time

import torch
import gradio as gr
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from flores200_codes import flores_codes


def load_models():
    # Map short model names to their Hugging Face Hub identifiers.
    # Only the distilled 600M checkpoint is loaded; the larger variants
    # are left commented out for reference.
    model_name_dict = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
        # 'nllb-1.3B': 'facebook/nllb-200-1.3B',
        # 'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
        # 'nllb-3.3B': 'facebook/nllb-200-3.3B',
    }

    model_dict = {}

    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        model_dict[call_name + '_model'] = model
        model_dict[call_name + '_tokenizer'] = tokenizer

    return model_dict


def translation(source, target, text):
    # Only the distilled 600M model is loaded in load_models().
    model_name = 'nllb-distilled-600M'

    start_time = time.time()

    # Convert the human-readable language names to FLORES-200 codes.
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    output = output[0]['translation_text']
    result = {'inference_time': end_time - start_time,
              'source': source,
              'target': target,
              'result': output}
    return result


if __name__ == '__main__':
    print('\tinit models')

    model_dict = load_models()

    # Define the Gradio demo.
    lang_codes = list(flores_codes.keys())
    # inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
    inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
              gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
              gr.inputs.Textbox(lines=5, label="Input text"),
              ]

    outputs = gr.outputs.JSON()

    title = "NLLB distilled 600M demo"

    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['English', 'Korean', 'Hi. nice to meet you'],
    ]

    gr.Interface(translation,
                 inputs,
                 outputs,
                 examples=examples,
                 title=title,
                 description=description,
                 ).launch()
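The commit does not include flores200_codes.py; app.py only assumes that module exposes a dict named flores_codes mapping human-readable language names to FLORES-200 codes (the codes NLLB-200 expects, e.g. eng_Latn, kor_Hang). A minimal two-entry stand-in, purely illustrative, could look like this:

# flores200_codes.py (hypothetical stand-in; the real file should cover every
# language the demo is meant to offer in the dropdowns)
flores_codes = {
    'English': 'eng_Latn',
    'Korean': 'kor_Hang',
}

With such a module on the path, translation('English', 'Korean', 'Hi. nice to meet you') returns a dict with the keys 'inference_time', 'source', 'target', and 'result', which the gr.outputs.JSON() component renders as-is.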