Tenzin Gyalpo committed on
Commit
f4f9f94
1 Parent(s): 254144d

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +90 -0
app.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from flores200_codes import flores_codes
7
+
8
+
9
def load_models():
    """Load each NLLB checkpoint listed in `checkpoints` and return a flat dict.

    The returned dict maps '<alias>_model' -> model instance and
    '<alias>_tokenizer' -> tokenizer instance, one pair per checkpoint.
    """
    # Alias -> Hugging Face hub id.  Commented entries are alternative sizes
    # that can be re-enabled at the cost of memory/loading time.
    checkpoints = {
        # 'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
        # 'nllb-1.3B': 'facebook/nllb-200-1.3B',
        'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
        # 'nllb-3.3B': 'facebook/nllb-200-3.3B',
    }

    loaded = {}
    for alias, hub_id in checkpoints.items():
        print('\tLoading model: %s' % alias)
        loaded[alias + '_model'] = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        loaded[alias + '_tokenizer'] = AutoTokenizer.from_pretrained(hub_id)

    return loaded
28
+
29
+
30
def translation(source, target, text):
    """Translate `text` between two languages and report timing.

    Parameters
    ----------
    source, target : str
        Human-readable language names; keys of `flores_codes`, which maps
        them to FLORES-200 language codes expected by the NLLB pipeline.
    text : str
        The text to translate.

    Returns
    -------
    dict with keys 'inference_time' (seconds), 'source'/'target' (the
    resolved FLORES-200 codes), 'result' (translated string) and
    'full_output' (the raw pipeline output list).

    Raises KeyError if `source`/`target` are not known language names.
    """
    # BUGFIX: the original only bound `model_name` inside
    # `if len(model_dict) == 2:`, so any other dict size caused a NameError
    # on the lookup below.  Exactly one checkpoint is loaded (one model +
    # one tokenizer), so select it unconditionally.
    model_name = 'nllb-distilled-1.3B'

    start_time = time.time()
    # Resolve display names to FLORES-200 codes (e.g. 'English' -> 'eng_Latn').
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    # Keep the raw pipeline output alongside the extracted translation string.
    full_output = output
    output = output[0]['translation_text']
    result = {'inference_time': end_time - start_time,
              'source': source,
              'target': target,
              'result': output,
              'full_output': full_output}
    return result
54
+
55
+
56
if __name__ == '__main__':
    print('\tinit models')

    # No-op at module scope; preserved from the original script.
    global model_dict

    model_dict = load_models()

    # --- Gradio demo definition -------------------------------------------
    lang_codes = list(flores_codes.keys())

    # A model-selection radio was considered but left disabled:
    # gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model')
    source_dropdown = gr.inputs.Dropdown(lang_codes, default='Standard Tibetan', label='Source')
    target_dropdown = gr.inputs.Dropdown(lang_codes, default='English', label='Target')
    text_input = gr.inputs.Textbox(lines=5, label="Input text")
    inputs = [source_dropdown, target_dropdown, text_input]

    outputs = gr.outputs.JSON()

    title = "NLLB distilled 1.3B distilled [སྐད་ཡིག་གཅིག་ཀྱང་མ་ལུས་པ།]"

    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['Standard Tibetan', 'English', 'ཁྱེད་ཀྱིས་ཟས་ཟིན་ནམ།']
    ]

    demo = gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    )
    demo.launch()