kamau1 commited on
Commit
156cab1
1 Parent(s): 3c6bf8a

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from flores200_codes import flores_codes
7
+
8
+
9
def load_models(model_name_dict=None):
    """Load translation model(s) and tokenizer(s) from the Hugging Face Hub.

    Parameters
    ----------
    model_name_dict : dict[str, str] | None
        Mapping of short call-name -> Hugging Face repo id. Defaults to the
        single NLLB-MoE 54B checkpoint, matching the original behavior.
        (Generalized from a hard-coded table so callers can load other
        NLLB variants without editing this function.)

    Returns
    -------
    dict
        For each call-name, two entries: '<call_name>_model' -> model and
        '<call_name>_tokenizer' -> tokenizer.
    """
    if model_name_dict is None:
        # Default preserves the original hard-coded choice.
        model_name_dict = {'nllb-moe-54b': 'facebook/nllb-moe-54b'}

    model_dict = {}
    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        model_dict[call_name + '_model'] = model
        model_dict[call_name + '_tokenizer'] = tokenizer

    return model_dict
29
+
30
+
31
def translation(source, target, text):
    """Translate *text* from *source* to *target* language.

    Parameters
    ----------
    source, target : str
        Human-readable language names; must be keys of ``flores_codes``,
        which maps them to FLORES-200 language codes.
    text : str
        The text to translate.

    Returns
    -------
    dict
        Keys: 'inference_time' (seconds), 'source' / 'target' (the resolved
        FLORES-200 codes) and 'result' (the translated string).
    """
    # BUG FIX: the original hard-coded model_name = 'nllb-distilled-600M'
    # (guarded by `if len(model_dict) == 2`), but load_models() registers
    # 'nllb-moe-54b', so the dictionary lookups below raised KeyError — and
    # model_name was unbound (NameError) whenever len(model_dict) != 2.
    # Derive the call-name from the registered keys instead; keys have the
    # form '<call_name>_model' / '<call_name>_tokenizer'.
    model_name = next(iter(model_dict)).rsplit('_', 1)[0]

    start_time = time.time()

    # Map human-readable names to FLORES-200 codes.
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    result = {'inference_time': end_time - start_time,
              'source': source,
              'target': target,
              'result': output[0]['translation_text']}
    return result
53
+
54
+
55
if __name__ == '__main__':
    print('\tinit models')

    global model_dict

    model_dict = load_models()

    # --- Gradio demo definition ---
    # NOTE: gr.inputs / gr.outputs are the Gradio 2.x-era component API;
    # kept as-is to preserve behavior on the pinned version.
    language_names = list(flores_codes.keys())

    source_dropdown = gr.inputs.Dropdown(language_names, default='English', label='Source')
    target_dropdown = gr.inputs.Dropdown(language_names, default='Korean', label='Target')
    text_input = gr.inputs.Textbox(lines=5, label="Input text")

    json_output = gr.outputs.JSON()

    demo_status = "Demo is running on CPU"
    page_title = "NLLB distilled 600M demo"
    page_description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"

    # Defined but not wired into the Interface (faithful to the original).
    examples = [
        ['English', 'Korean', 'Hi. nice to meet you']
    ]

    gr.Interface(
        translation,
        [source_dropdown, target_dropdown, text_input],
        json_output,
        title=page_title,
        description=page_description,
    ).launch()