ashourzadeh7 committed
Commit 518fac7 (verified)
1 Parent(s): 3ca116b

Update app.py

Files changed (1)
  1. app.py +10 -54
app.py CHANGED
@@ -6,74 +6,30 @@ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
 from flores200_codes import flores_codes
 
 
-def load_models():
-    # build model and tokenizer
-    model_name_dict = {#'nllb-finetuned-kutofa': 'ashourzadeh7/nllb-finetuned-kutofa',
-                       #'nllb-1.3B': 'facebook/nllb-200-1.3B',
-                       #'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
-                       'nllb-3.3B': 'facebook/nllb-200-3.3B',
-                       }
-
-    model_dict = {}
-
-    for call_name, real_name in model_name_dict.items():
-        print('\tLoading model: %s' % call_name)
-        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
-        tokenizer = AutoTokenizer.from_pretrained(real_name)
-        model_dict[call_name+'_model'] = model
-        model_dict[call_name+'_tokenizer'] = tokenizer
-
-    return model_dict
-
-
-def translation(source, target, text):
-    if len(model_dict) == 2:
-        model_name = 'nllb-3.3B'
-
-    start_time = time.time()
-    source = flores_codes[source]
-    target = flores_codes[target]
-
-    model = model_dict[model_name + '_model']
-    tokenizer = model_dict[model_name + '_tokenizer']
-
-    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
-    output = translator(text, max_length=400)
-
-    end_time = time.time()
-
-    output = output[0]['translation_text']
-    result = {'inference_time': end_time - start_time,
-              'source': source,
-              'target': target,
-              'result': output}
-    return result
+def transfer(input):
+    with open(input, 'r', encoding="utf-8") as f:
+        text = f.read()
+
+    output_file = "out.txt"
+    with open(output_file, 'w', encoding="utf-8") as f:
+        file = f.write(text)
+    return file
 
 
 if __name__ == '__main__':
-    print('\tinit models')
-
-    global model_dict
-
-    model_dict = load_models()
-
-    # define gradio demo
-    lang_codes = list(flores_codes.keys())
+
     #inputs = [gr.inputs.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
-    inputs = [gr.components.Dropdown(label='Source', choices=lang_codes),
-              gr.components.Dropdown(label='Target', choices=lang_codes),
-              gr.components.Textbox(lines=5, label="Input text"),
-              ]
+    inputs = [gr.components.file(label="Input File")]
 
-    outputs = gr.components.JSON()
+    outputs = gr.components.file(label="Translated File", value=file)
 
     title = "NLLB distilled 600M demo"
 
     demo_status = "Demo is running on CPU"
     description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
-    examples = [
-        ['فارسی', 'کردی', 'سلام، حالتون خوبه؟']
-    ]
 
     gr.Interface(translation,
                  inputs,
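
As committed, the new version has a few rough edges: transfer returns the character count from f.write(text) rather than the path of the file it wrote, gr.components.file (lowercase) does not match the usual gr.components.File class, outputs references a variable named file that is not defined at that point, and gr.Interface is still wired to translation, a function this commit deletes. Below is a minimal sketch of what the file-based interface presumably intends, keeping only the transfer name and the out.txt path from the commit; the component arguments and the Interface wiring are assumptions, and the "translation" step here is still just a copy of the input text.

import gradio as gr


def transfer(input_path):
    # Read the uploaded file and copy its contents to out.txt, then return
    # the output path so Gradio can offer the file for download.
    with open(input_path, "r", encoding="utf-8") as f:
        text = f.read()

    output_file = "out.txt"
    with open(output_file, "w", encoding="utf-8") as f:
        f.write(text)
    return output_file


if __name__ == "__main__":
    # type="filepath" assumes a recent Gradio (4.x or a late 3.x release);
    # older releases pass a temp-file object instead of a plain path.
    inputs = [gr.components.File(label="Input File", type="filepath")]
    outputs = gr.components.File(label="Translated File")

    title = "NLLB distilled 600M demo"
    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"

    gr.Interface(transfer,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 ).launch()

Wiring the NLLB model back in would amount to running the removed translation pipeline over text inside transfer and writing its output, rather than the raw input, to out.txt.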