futranbg committed on
Commit
743df17
1 Parent(s): 0829b3f

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +86 -0
app.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from flores200_codes import flores_codes
7
+
8
def load_models():
    """Load every enabled NLLB checkpoint and its tokenizer.

    Returns:
        dict: flat mapping keyed ``'<name>_model'`` and ``'<name>_tokenizer'``
        for each entry enabled in the checkpoint table below.
    """
    # Only the distilled 1.3B checkpoint is enabled; the commented entries
    # record the other sizes that were considered.
    model_name_dict = {#'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
                       #'nllb-1.3B': 'facebook/nllb-200-1.3B',
                       'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
                       #'nllb-3.3B': 'facebook/nllb-200-3.3B',
                       }

    loaded = {}
    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        loaded['%s_model' % call_name] = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        loaded['%s_tokenizer' % call_name] = AutoTokenizer.from_pretrained(real_name)

    return loaded
26
+
27
+
28
def translation(source, target, text):
    """Translate *text* line by line from *source* to *target* language.

    Args:
        source: human-readable source language name (a ``flores_codes`` key).
        target: human-readable target language name (a ``flores_codes`` key).
        text: input text; translated one line at a time to keep each chunk
            within ``max_length``.

    Returns:
        dict with ``inference_time`` (seconds), resolved ``source``/``target``
        FLORES codes, and the translated ``result`` text.

    Raises:
        KeyError: if *source* or *target* is not present in ``flores_codes``.
    """
    # FIX: model_name was previously assigned only when
    # len(model_dict) == 2, which raised NameError for any other size.
    # The loader only ever enables this one checkpoint, so select it directly.
    model_name = 'nllb-distilled-1.3B'

    start_time = time.time()
    source = flores_codes[source]
    target = flores_codes[target]

    # model_dict is populated at startup by load_models() — see __main__.
    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)

    output = ""
    for chunk in text.splitlines(True):
        # FIX: skip whitespace-only lines — feeding "\n" to the model wastes
        # a beam-search call and can emit spurious text. Blank lines are kept
        # as blank lines in the output.
        if not chunk.strip():
            output += "\n"
            continue
        stchunk = translator(chunk, max_length=500, skip_special_tokens=True, num_beams=5)
        output += stchunk[0]['translation_text'] + "\n"

    end_time = time.time()

    result = {'inference_time': end_time - start_time,
              'source': source,
              'target': target,
              'result': output}
    return result
54
+
55
+
56
if __name__ == '__main__':
    print('\tinit models')

    # Loaded once at startup; translation() reads this module-level dict.
    # (FIX: removed a no-op `global model_dict` — `global` has no effect at
    # module scope.)
    model_dict = load_models()

    # Define the gradio demo.
    lang_codes = list(flores_codes.keys())
    inputs = [gr.inputs.Dropdown(lang_codes, default='English', label='Source'),
              gr.inputs.Dropdown(lang_codes, default='Korean', label='Target'),
              gr.inputs.Textbox(lines=5, label="Input text"),
              ]

    outputs = gr.outputs.JSON()

    # FIX: title previously said "600M" while the only checkpoint loaded by
    # load_models() is facebook/nllb-200-distilled-1.3B.
    title = "NLLB distilled 1.3B demo"

    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['English', 'Korean', 'Hi. nice to meet you']
    ]

    gr.Interface(translation,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 # FIX: `examples` was built but never wired into the UI.
                 examples=examples,
                 ).launch()