Danil committed on
Commit
f4366b8
1 Parent(s): 4aafedf

Create new file

Browse files
Files changed (1) hide show
  1. app.py +83 -0
app.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from flores200_codes import flores_codes
7
+
8
+
9
+ def load_models():
10
+ # build model and tokenizer
11
+ model_name_dict = {'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B'}
12
+
13
+ model_dict = {}
14
+
15
+ for call_name, real_name in model_name_dict.items():
16
+ print('\tLoading model: %s' % call_name)
17
+ model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
18
+ tokenizer = AutoTokenizer.from_pretrained(real_name)
19
+ model_dict[call_name+'_model'] = model
20
+ model_dict[call_name+'_tokenizer'] = tokenizer
21
+
22
+ return model_dict
23
+
24
+
25
+ def translation(source, target, text):
26
+ if len(model_dict) == 2:
27
+ model_name = 'nllb-distilled-1.3B'
28
+
29
+ start_time = time.time()
30
+ source = flores_codes[source]
31
+ target = flores_codes[target]
32
+
33
+ model = model_dict[model_name + '_model']
34
+ tokenizer = model_dict[model_name + '_tokenizer']
35
+
36
+ translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
37
+ output = translator(text, max_length=400)
38
+
39
+ end_time = time.time()
40
+
41
+ output = output[0]['translation_text']
42
+ result = {'inference_time': end_time - start_time,
43
+ 'source': source,
44
+ 'target': target,
45
+ 'result': output}
46
+ return result
47
+
48
+
49
+ if __name__ == '__main__':
50
+ print('\tinit models')
51
+
52
+ global model_dict
53
+
54
+ model_dict = load_models()
55
+
56
+ # define gradio demo
57
+ lang_codes = list(flores_codes.keys())
58
+ #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
59
+ inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
60
+ gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
61
+ gr.inputs.Textbox(lines=5, label="Input text"),
62
+ ]
63
+
64
+ outputs = gr.outputs.JSON()
65
+
66
+ title = "NLLB distilled 1.3B demo"
67
+
68
+ demo_status = "Demo is running on CPU"
69
+ description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
70
+ examples = [
71
+ ['English', 'Korean', 'Hi. nice to meet you']
72
+ ]
73
+
74
+ gr.Interface(translation,
75
+ inputs,
76
+ outputs,
77
+ title=title,
78
+ description=description,
79
+ examples=examples,
80
+ examples_per_page=50,
81
+ ).launch()
82
+
83
+