abinashbordoloi committed on
Commit
708d084
·
1 Parent(s): 8485848

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +125 -42
app.py CHANGED
@@ -1,66 +1,149 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
  import gradio as gr
2
  import time
3
- from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, pipeline
4
- from supported_languages import LANGS
5
 
6
- def load_model():
7
- # model_name = 'nllb-moe-54b'
8
-
9
- model_name = 'nllb-200-distilled-600M'
10
- print('\tLoading model: %s' % model_name)
11
- model = AutoModelForSeq2SeqLM.from_pretrained(f'facebook/{model_name}')
12
- tokenizer = NllbTokenizer.from_pretrained(f'facebook/{model_name}')
13
- return model, tokenizer
14
 
15
- model, tokenizer = load_model()
 
 
 
 
 
 
 
16
 
17
  def translation(source, target, text):
 
 
 
 
 
 
18
  start_time = time.time()
19
- source_code = LANGS[source]
20
- target_code = LANGS[target]
21
 
22
- source_langauge = source
23
- target_language = target
24
 
25
  translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
26
  output = translator(text, max_length=400)
27
- input_text = text
28
  end_time = time.time()
29
 
30
  full_output = output
31
  output = output[0]['translation_text']
32
- result = {
33
- 'inference_time': end_time - start_time,
34
- 'source': source_language,
35
- 'target': target_language,
36
- 'input_text': input_text,
37
- 'result': output,
38
- 'full_output': full_output
39
- }
40
  return result
41
 
42
  if __name__ == '__main__':
43
- # Define gradio demo
44
- lang_codes = list(LANGS.keys())
45
- inputs = [
46
- gr.Dropdown(lang_codes, label='Source'),
47
- gr.Dropdown(lang_codes, label='Target'),
48
- gr.Textbox(lines=5, label="Input text"),
49
- ]
 
 
 
 
50
 
51
  outputs = gr.JSON()
52
 
53
- title = "NLLB distilled 1.3B distilled【多语言翻译器】"
54
  demo_status = "Demo is running on CPU"
55
  description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
56
-
 
 
 
 
 
 
 
 
 
57
 
58
- gr.Interface(
59
- translation,
60
- inputs,
61
- outputs,
62
- title=title,
63
- description=description,
64
-
65
- examples_per_page=50,
66
- ).launch()
 
1
+ # import gradio as gr
2
+ # import time
3
+ # from transformers import NllbTokenizer, AutoModelForSeq2SeqLM, pipeline
4
+ # from supported_languages import LANGS
5
+
6
+ # def load_model():
7
+ # # model_name = 'nllb-moe-54b'
8
+
9
+ # model_name = 'nllb-200-distilled-600M'
10
+ # print('\tLoading model: %s' % model_name)
11
+ # model = AutoModelForSeq2SeqLM.from_pretrained(f'facebook/{model_name}')
12
+ # tokenizer = NllbTokenizer.from_pretrained(f'facebook/{model_name}')
13
+ # return model, tokenizer
14
+
15
+ # model, tokenizer = load_model()
16
+
17
+ # def translation(source, target, text):
18
+ # start_time = time.time()
19
+ # source_code = LANGS[source]
20
+ # target_code = LANGS[target]
21
+
22
+ # source_langauge = source
23
+ # target_language = target
24
+
25
+ # translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
26
+ # output = translator(text, max_length=400)
27
+ # input_text = text
28
+ # end_time = time.time()
29
+
30
+ # full_output = output
31
+ # output = output[0]['translation_text']
32
+ # result = {
33
+ # 'inference_time': end_time - start_time,
34
+ # 'source': source_language,
35
+ # 'target': target_language,
36
+ # 'input_text': input_text,
37
+ # 'result': output,
38
+ # 'full_output': full_output
39
+ # }
40
+ # return result
41
+
42
+ # if __name__ == '__main__':
43
+ # # Define gradio demo
44
+ # lang_codes = list(LANGS.keys())
45
+ # inputs = [
46
+ # gr.Dropdown(lang_codes, label='Source'),
47
+ # gr.Dropdown(lang_codes, label='Target'),
48
+ # gr.Textbox(lines=5, label="Input text"),
49
+ # ]
50
+
51
+ # outputs = gr.JSON()
52
+
53
+ # title = "NLLB distilled 1.3B distilled【多语言翻译器】"
54
+ # demo_status = "Demo is running on CPU"
55
+ # description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
56
+
57
+
58
+ # gr.Interface(
59
+ # translation,
60
+ # inputs,
61
+ # outputs,
62
+ # title=title,
63
+ # description=description,
64
+
65
+ # examples_per_page=50,
66
+ # ).launch()
67
+
68
+
69
+ import os
70
+ import torch
71
  import gradio as gr
72
  import time
73
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
74
+ from flores200_codes import flores_codes
75
 
76
def load_models():
    """Load every configured NLLB checkpoint with its tokenizer.

    Returns:
        dict: flat registry keyed by ``'<alias>_model'`` and
        ``'<alias>_tokenizer'`` for each configured checkpoint.
    """
    # Alias -> HuggingFace repo id for each checkpoint we serve.
    checkpoints = {
        'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
    }

    registry = {}
    for alias, repo_id in checkpoints.items():
        print('\tLoading model: %s' % alias)
        registry[f'{alias}_model'] = AutoModelForSeq2SeqLM.from_pretrained(repo_id)
        registry[f'{alias}_tokenizer'] = AutoTokenizer.from_pretrained(repo_id)

    return registry
92
 
93
def translation(source, target, text):
    """Translate ``text`` from ``source`` to ``target`` language.

    Args:
        source: human-readable source language name (a key of ``flores_codes``).
        target: human-readable target language name (a key of ``flores_codes``).
        text: input text to translate.

    Returns:
        dict with ``inference_time``, ``source``/``target`` (FLORES-200
        codes, as before), ``result`` (translated string) and
        ``full_output`` (raw pipeline output) — or an ``error`` dict when
        the model was never loaded.
    """
    model_name = 'nllb-distilled-1.3B'

    # Explicit error payload instead of a bare None: the gradio JSON
    # component renders this dict rather than an unhelpful null.
    if model_name + '_model' not in model_dict:
        print(f"Model '{model_name}' not found in model_dict.")
        return {'error': f"Model '{model_name}' not found in model_dict."}

    start_time = time.time()
    # Map human-readable names to FLORES-200 codes without clobbering
    # the original parameters.
    source_code = flores_codes[source]
    target_code = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    translator = pipeline('translation', model=model, tokenizer=tokenizer,
                          src_lang=source_code, tgt_lang=target_code)
    output = translator(text, max_length=400)

    end_time = time.time()

    full_output = output
    output = output[0]['translation_text']
    # 'source'/'target' stay the FLORES codes for backward compatibility
    # with the previous output schema.
    result = {'inference_time': end_time - start_time,
              'source': source_code,
              'target': target_code,
              'result': output,
              'full_output': full_output}
    return result
120
 
121
if __name__ == '__main__':
    print('\tinit models')

    # Loaded once at startup; translation() reads this module-level dict.
    # (The previous `global model_dict` was a no-op at module level.)
    model_dict = load_models()

    # Define the gradio demo.
    lang_codes = list(flores_codes.keys())
    inputs = [gr.Dropdown(lang_codes, label='Source'),
              gr.Dropdown(lang_codes, label='Target'),
              gr.Textbox(lines=5, label="Input text"),
              ]

    outputs = gr.JSON()

    # Previous title contained a mojibake/citation artifact
    # ("...``【oaicite:0】``..." with zero-width spaces); restored to plain text.
    title = "NLLB distilled 1.3B demo"
    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [['Chinese (Simplified)', 'English', '你吃饭了吗?']]

    gr.Interface(translation,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 examples=examples,
                 examples_per_page=50,
                 ).launch()
149