import time
import threading

import gradio as gr
from flask import Flask, request
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline

from flores200_codes import flores_codes


def load_models():
    """Load each model/tokenizer pair and return them in a single dict,
    keyed by '<call_name>_model' and '<call_name>_tokenizer'."""
    # Map a short call name to its Hugging Face model id; uncomment
    # entries to serve additional NLLB variants.
    model_name_dict = {
        # 'nllb-distilled-600M': 'facebook/nllb-200-distilled-600M',
        # 'nllb-1.3B': 'facebook/nllb-200-1.3B',
        'nllb-distilled-1.3B': 'facebook/nllb-200-distilled-1.3B',
        # 'nllb-3.3B': 'facebook/nllb-200-3.3B',
    }

    model_dict = {}

    for call_name, real_name in model_name_dict.items():
        print('\tLoading model: %s' % call_name)
        model = AutoModelForSeq2SeqLM.from_pretrained(real_name)
        tokenizer = AutoTokenizer.from_pretrained(real_name)
        model_dict[call_name + '_model'] = model
        model_dict[call_name + '_tokenizer'] = tokenizer

    return model_dict


def translation(source, target, text):
    """Translate `text` from `source` to `target` (FLORES-200 language names)."""
    # Only one model is loaded above; selecting it unconditionally avoids a
    # NameError if the dict size ever changes.
    model_name = 'nllb-distilled-1.3B'

    start_time = time.time()
    # map human-readable language names to FLORES-200 codes
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + '_model']
    tokenizer = model_dict[model_name + '_tokenizer']

    # Note: building the pipeline per request keeps the code simple at the
    # cost of some per-call overhead.
    translator = pipeline('translation', model=model, tokenizer=tokenizer, src_lang=source, tgt_lang=target)
    output = translator(text, max_length=400)

    end_time = time.time()

    full_output = output
    output = output[0]['translation_text']
    result = {'inference_time': end_time - start_time,
              'source': source,
              'target': target,
              'result': output,
              'full_output': full_output}
    return result
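
# A minimal usage sketch (hypothetical; assumes the models have been loaded
# into the module-level model_dict and that 'Yue Chinese' and 'English' are
# keys in flores_codes):
#
#     model_dict = load_models()
#     out = translation('Yue Chinese', 'English', '你食咗飯未?')
#     print(out['result'])           # translated text
#     print(out['inference_time'])   # seconds spent in the pipeline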


def start_flask():
    # Minimal HTTP API exposing the same translation function as the UI.
    app = Flask(__name__)

    @app.route('/translate', methods=['POST'])
    def translate():
        source = request.form['source']
        target = request.form['target']
        text = request.form['text']
        result = translation(source, target, text)
        return result

    app.run()
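
# A minimal client sketch (hypothetical, assuming the Flask default of
# http://127.0.0.1:5000 and that the `requests` package is installed):
#
#     import requests
#     resp = requests.post('http://127.0.0.1:5000/translate',
#                          data={'source': 'Yue Chinese',
#                                'target': 'English',
#                                'text': '你食咗飯未?'})
#     print(resp.json()['result'])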


if __name__ == '__main__':
    print('\tinit models')

    model_dict = load_models()

    # Run the HTTP API in a background thread alongside the Gradio UI.
    threading.Thread(target=start_flask, daemon=True).start()

    # define gradio demo (gr.inputs/gr.outputs were removed in Gradio 3+;
    # the components and their `value=` defaults are used directly)
    lang_codes = list(flores_codes.keys())
    # inputs = [gr.Radio(['nllb-distilled-600M', 'nllb-1.3B', 'nllb-distilled-1.3B'], label='NLLB Model'),
    inputs = [gr.Dropdown(lang_codes, value='Yue Chinese', label='Source'),
              gr.Dropdown(lang_codes, value='English', label='Target'),
              gr.Textbox(lines=5, label='Input text'),
              ]

    outputs = gr.JSON()

    title = "NLLB distilled 1.3B demo"

    demo_status = "Demo is running on CPU"
    description = f"Details: https://github.com/facebookresearch/fairseq/tree/nllb. {demo_status}"
    examples = [
        ['Yue Chinese', 'English', '你食咗飯未?']
    ]

    gr.Interface(translation,
                 inputs,
                 outputs,
                 title=title,
                 description=description,
                 examples=examples,
                 examples_per_page=50,
                 ).launch(share=True)