yaya-sy committed on
Commit
ac8cea6
1 Parent(s): 2260d99

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +85 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+
7
+
8
def load_models():
    """Load each configured NLLB checkpoint together with its tokenizer.

    Returns:
        dict: for every short model name, the model is stored under
        ``'<name>_model'`` and the tokenizer under ``'<name>_tokenizer'``.
    """
    # Short name -> identifier handed to ``from_pretrained``.
    # Other checkpoints that are currently left disabled:
    #   'nllb-1.3B': 'facebook/nllb-200-1.3B'
    #   'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B'
    #   'nllb-3.3B': 'facebook/nllb-200-3.3B'
    checkpoints = {
        'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
    }

    loaded = {}
    for short_name, pretrained_id in checkpoints.items():
        print('\tLoading model: %s' % short_name)
        # Model first, then tokenizer, each registered under its own key.
        loaded[short_name + '_model'] = AutoModelForSeq2SeqLM.from_pretrained(pretrained_id)
        loaded[short_name + '_tokenizer'] = AutoTokenizer.from_pretrained(pretrained_id)

    return loaded
26
+
27
+
28
def translation(source, target, text):
    """Translate ``text`` from ``source`` to ``target`` with the loaded NLLB model.

    Args:
        source: language code passed to the pipeline as ``src_lang``.
        target: language code passed to the pipeline as ``tgt_lang``.
        text: the text to translate.

    Returns:
        dict with the keys ``'inference_time'`` (elapsed seconds for building
        the pipeline and running it), ``'source'``, ``'target'`` and
        ``'result'`` (the translated text).

    Raises:
        RuntimeError: if the expected model/tokenizer entries are missing
            from the module-level ``model_dict``.
    """
    # Always bind the model name. Previously it was only assigned when the
    # registry held exactly two entries, so any other registry size crashed
    # with UnboundLocalError a few lines further down.
    model_name = 'nllb-distilled-600M'

    try:
        model = model_dict[model_name + '_model']
        tokenizer = model_dict[model_name + '_tokenizer']
    except KeyError as err:
        raise RuntimeError(
            'model %r is not loaded; available entries: %s'
            % (model_name, sorted(model_dict))
        ) from err

    # perf_counter is monotonic, so the measured duration cannot be skewed
    # by system clock adjustments.
    start_time = time.perf_counter()

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.perf_counter()

    # The pipeline output is indexed as a list of dicts; keep the first text.
    translated = output[0]['translation_text']
    return {
        'inference_time': end_time - start_time,
        'source': source,
        'target': target,
        'result': translated,
    }
48
+
49
+
50
if __name__ == '__main__':
    print('\tinit models')

    # translation() reads this module-level name. No `global` statement is
    # needed because this code already runs at module scope.
    model_dict = load_models()

    # define gradio demo
    # NOTE(review): gr.inputs / gr.outputs look like the legacy Gradio
    # component namespaces -- confirm they exist in the pinned gradio version.
    lang_codes = ["eng_Latn", "fuv_Latn", "fra_Latn", "arb_Arab"]
    inputs = [
        gr.inputs.Dropdown(lang_codes, default='fra_Latn', label='Source'),
        gr.inputs.Dropdown(lang_codes, default='fuv_Latn', label='Target'),
        gr.inputs.Textbox(lines=5, label="Input text"),
    ]

    outputs = gr.outputs.JSON()

    title = "Fulfulde translator"

    demo_status = "Demo is running on CPU"
    description = "to French, English or Arabic and vice-versa translation demo using NLLB."
    # Example values must match the Dropdown choices exactly: the target code
    # is 'fuv_Latn' (capital L), as listed in lang_codes.
    examples = [
        ['fra_Latn', 'fuv_Latn', 'La traduction est une tâche facile.'],
    ]

    gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()
83
+
84
+
85
+