File size: 3,482 Bytes
afb6c6a
 
 
 
 
 
 
 
 
b601f49
afb6c6a
2f4418e
c50283a
afb6c6a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
74a97fe
 
 
 
 
 
 
 
 
 
afb6c6a
 
 
74a97fe
 
 
afb6c6a
 
 
 
 
74a97fe
afb6c6a
74a97fe
afb6c6a
 
 
 
 
 
 
 
 
 
 
 
 
c50283a
b601f49
c50283a
2f4418e
74a97fe
c50283a
afb6c6a
c9328c0
afb6c6a
 
 
 
 
 
74a97fe
 
 
 
afb6c6a
 
 
 
773033a
 
 
b601f49
 
74a97fe
 
 
 
 
 
b601f49
afb6c6a
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
import gradio as gr
import time
from transformers import AutoTokenizer, AutoModelForSeq2SeqLM, pipeline
from flores200_codes import flores_codes


def load_models():
    """Load every configured NLLB checkpoint together with its tokenizer.

    Returns:
        dict: a flat mapping where, for each short model name ``name``,
        ``name + "_model"`` holds the seq2seq model and
        ``name + "_tokenizer"`` holds the matching tokenizer.
    """
    # short display name -> checkpoint id passed to ``from_pretrained``
    checkpoints = {
        "nllb-distilled-600M": "facebook/nllb-200-distilled-600M",
        "nllb-distilled-1.3B": "facebook/nllb-200-distilled-1.3B",
        # "nllb-1.3B": "facebook/nllb-200-1.3B",
        # "nllb-3.3B": "facebook/nllb-200-3.3B",
    }

    loaded = {}
    for short_name in checkpoints:
        hub_id = checkpoints[short_name]
        print(f"\tLoading model: {short_name}")
        # Insert the model first, then the tokenizer, per checkpoint.
        loaded[f"{short_name}_model"] = AutoModelForSeq2SeqLM.from_pretrained(hub_id)
        loaded[f"{short_name}_tokenizer"] = AutoTokenizer.from_pretrained(hub_id)
    return loaded


def translation(model_name, source, target, text):
    """Translate ``text`` line by line with the selected NLLB model.

    Gradio callback wired to the four inputs of the demo interface.

    Args:
        model_name: Prefix of the keys in the module-level ``model_dict``
            (assigned in the ``__main__`` block from ``load_models()``),
            e.g. ``"nllb-distilled-600M"``.
        source: Source language display name; must be a key of ``flores_codes``.
        target: Target language display name; must be a key of ``flores_codes``.
        text: Input text. Each ``"\\n"``-separated line is translated on its
            own. Blank (or whitespace-only) lines are copied through unchanged,
            so the output keeps the input's line layout.

    Returns:
        A ``(metadata, output)`` tuple. ``metadata`` is a dict with the elapsed
        time in seconds, the mapped source/target language codes and the
        translated text. ``output`` is the translated text itself.

    Raises:
        KeyError: if ``source``/``target`` is not in ``flores_codes`` or
            ``model_name`` has no entry in ``model_dict``.
    """
    # perf_counter is monotonic and high-resolution; time.time() can jump
    # when the wall clock is adjusted, skewing the reported duration.
    start_time = time.perf_counter()
    source = flores_codes[source]
    target = flores_codes[target]

    model = model_dict[model_name + "_model"]
    tokenizer = model_dict[model_name + "_tokenizer"]

    translator = pipeline(
        "translation",
        model=model,
        tokenizer=tokenizer,
        src_lang=source,
        tgt_lang=target,
    )

    # Sentence-wise (line-wise) translation. Blank lines are not sent to the
    # model: there is nothing to translate, the model may emit spurious text
    # for an empty input, and skipping them saves a forward pass.
    translated_sentences = []
    for sentence in text.split("\n"):
        if not sentence.strip():
            translated_sentences.append(sentence)
            continue
        translated_sentences.append(
            translator(sentence, max_length=400)[0]["translation_text"]
        )
    output = "\n".join(translated_sentences)

    inference_time = time.perf_counter() - start_time

    result = {
        "inference_time": inference_time,
        "source": source,
        "target": target,
        "result": output,
    }
    return result, output


if __name__ == "__main__":
    print("\tinit models")

    # Assigned at module scope, so ``translation`` reads it as a global.
    model_dict = load_models()

    # Derive the selectable models from what load_models() actually loaded,
    # so the UI cannot offer a model missing from model_dict. Keys come in
    # insertion order ("<name>_model", "<name>_tokenizer", ...), so the
    # choice order matches the loading order.
    model_names = [
        key[: -len("_model")] for key in model_dict if key.endswith("_model")
    ]

    # define gradio demo
    # NOTE(review): the ``gr.inputs`` / ``gr.outputs`` namespaces are
    # deprecated in gradio 3.x and removed in 4.x; migrate to ``gr.Radio`` /
    # ``gr.Dropdown`` / ``gr.Textbox`` / ``gr.JSON`` (with ``value=`` instead
    # of ``default=``) once the pinned gradio version allows it.
    lang_codes = list(flores_codes.keys())
    inputs = [
        gr.inputs.Radio(
            model_names,
            label="NLLB Model",
            default="nllb-distilled-1.3B",
        ),
        gr.inputs.Dropdown(lang_codes, default="Najdi Arabic", label="Source"),
        gr.inputs.Dropdown(lang_codes, default="English", label="Target"),
        gr.inputs.Textbox(lines=5, label="Input text"),
    ]

    outputs = [
        gr.outputs.JSON(label="Metadata"),
        gr.outputs.Textbox(label="Output text"),
    ]

    title = "NLLB (No Language Left Behind) demo"

    demo_status = "Demo is running on CPU"
    # The description is rendered as Markdown; a paragraph indented by four
    # spaces would display as a code block, so keep every line flush-left.
    description = (
        "Using NLLB model, details: "
        "https://github.com/facebookresearch/fairseq/tree/nllb.\n"
        "\n"
        f"{demo_status}"
    )
    examples = [
        ["nllb-distilled-1.3B", "Najdi Arabic", "English", "جلست اطفال"],
        [
            "nllb-distilled-600M",
            "Najdi Arabic",
            "English",
            "شد للبيع طابقين مع شرع له نظيف حق غمارتين",
        ],
    ]

    gr.Interface(
        translation,
        inputs,
        outputs,
        title=title,
        description=description,
        examples=examples,
        examples_per_page=50,
    ).launch()