ashourzadeh7 commited on
Commit
3ca116b
·
verified ·
1 Parent(s): a3f20e3

Upload 3 files

Browse files
Files changed (3) hide show
  1. app.py +85 -0
  2. flores200_codes.py +9 -0
  3. requirements.txt +4 -0
app.py ADDED
@@ -0,0 +1,85 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import torch
3
+ import gradio as gr
4
+ import time
5
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
6
+ from flores200_codes import flores_codes
7
+
8
+
9
+ def load_models():
10
+ # build model and tokenizer
11
+ model_name_dict = {#'nllb-finetuned-kutofa': 'ashourzadeh7/nllb-finetuned-kutofa',
12
+ #'nllb-1.3B': 'facebook/nllb-200-1.3B',
13
+ #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
14
+ 'nllb-3.3B': 'facebook/nllb-200-3.3B',
15
+ }
16
+
17
+ model_dict = {}
18
+
19
+ for call_name, real_name in model_name_dict.items():
20
+ print('\tLoading model: %s' % call_name)
21
+ model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
22
+ tokenizer = AutoTokenizer.from_pretrained(real_name)
23
+ model_dict[call_name+'_model'] = model
24
+ model_dict[call_name+'_tokenizer'] = tokenizer
25
+
26
+ return model_dict
27
+
28
+
29
+ def translation(source, target, text):
30
+ if len(model_dict) == 2:
31
+ model_name = 'nllb-3.3B'
32
+
33
+ start_time = time.time()
34
+ source = flores_codes[source]
35
+ target = flores_codes[target]
36
+
37
+ model = model_dict[model_name + '_model']
38
+ tokenizer = model_dict[model_name + '_tokenizer']
39
+
40
+ translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
41
+ output = translator(text, max_length=400)
42
+
43
+ end_time = time.time()
44
+
45
+ output = output[0]['translation_text']
46
+ result = {'inference_time': end_time - start_time,
47
+ 'source': source,
48
+ 'target': target,
49
+ 'result': output}
50
+ return result
51
+
52
+
53
+ if __name__ == '__main__':
54
+ print('\tinit models')
55
+
56
+ global model_dict
57
+
58
+ model_dict = load_models()
59
+
60
+ # define gradio demo
61
+ lang_codes = list(flores_codes.keys())
62
+ #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
63
+ inputs = [gr.components.Dropdown(label='Source', choices=lang_codes),
64
+ gr.components.Dropdown(label='Target', choices=lang_codes),
65
+ gr.components.Textbox(lines=5, label="Input text"),
66
+ ]
67
+
68
+ outputs = gr.components.JSON()
69
+
70
+ title = "NLLB distilled 600M demo"
71
+
72
+ demo_status = "Demo is running on CPU"
73
+ description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
74
+ examples = [
75
+ ['فارسی', 'کردی', 'سلام، حالتون خوبه؟']
76
+ ]
77
+
78
+ gr.Interface(translation,
79
+ inputs,
80
+ outputs,
81
+ title=title,
82
+ description=description,
83
+ ).launch()
84
+
85
+
flores200_codes.py ADDED
@@ -0,0 +1,9 @@
 
 
 
 
 
 
 
 
 
 
1
+ codes_as_string = '''فارسی pes_Arab
2
+ کردی ckb_Arab'''
3
+
4
+ codes_as_string = codes_as_string.split('\n')
5
+
6
+ flores_codes = {}
7
+ for code in codes_as_string:
8
+ lang, lang_code = code.split('\t')
9
+ flores_codes[lang] = lang_code
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ transformers
2
+ gradio
3
+ torch
4
+ httpx==0.24.1