File size: 3,762 Bytes
7534093
 
4fa90b9
7534093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
4fa90b9
7534093
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
import torch
import gradio as gr
from optimum.pipelines import pipeline
import ast

# Dropdown label -> Hugging Face pipeline task string.
# Every task string follows the ``translation_<src>_to_<tgt>`` pattern with
# lowercase language codes, so the pipeline parses source/target consistently.
# Commented-out pairs are disabled; uncomment an entry to expose it in the UI.
translation_task_names = {
    'English to French': 'translation_en_to_fr',
#    'French to English': 'translation_fr_to_en',
#    'English to Spanish': 'translation_en_to_es',
#    'Spanish to English': 'translation_es_to_en',
    'English to German': 'translation_en_to_de',
#    'German to English': 'translation_de_to_en',
#    'English to Italian': 'translation_en_to_it',
#    'Italian to English': 'translation_it_to_en',
    'English to Dutch': 'translation_en_to_nl',
    'Dutch to English': 'translation_nl_to_en',
#    'English to Portuguese': 'translation_en_to_pt',
#    'Portuguese to English': 'translation_pt_to_en',
    'English to Russian': 'translation_en_to_ru',
    'Russian to English': 'translation_ru_to_en',
    'English to Chinese': 'translation_en_to_zh',
    'Chinese to English': 'translation_zh_to_en',
#    'English to Japanese': 'translation_en_to_ja',
#    'Japanese to English': 'translation_ja_to_en',
    'English to Romanian': 'translation_en_to_ro',
    # Lowercase codes, consistent with every other entry (was 'SV'/'EN').
    'Swedish to English': 'translation_sv_to_en',
}

# Dropdown label -> Hugging Face Hub model id passed to ``pipeline(model=...)``.
# Built from an ordered list of pairs; insertion order is the dropdown order.
model_names = dict(
    [
        # Google T5 checkpoints
        ('T5-Base', 't5-base'),
        ('T5-Small', 't5-small'),
        ('T5-Large', 't5-large'),
        # English <-> Chinese
        ('Opus-En-ZH', 'liam168/trans-opus-mt-en-zh'),
        ('Opus-ZH-En', 'Helsinki-NLP/opus-mt-zh-en'),
        ('DDDSSS/translation_en-zh', 'DDDSSS/translation_en-zh'),
        # Dutch -> English
        ('T5-Base-nl-en', 'yhavinga/t5-base-36L-ccmatrix-multi'),
        ('T5-Small-nl-en', 'yhavinga/t5-small-24L-ccmatrix-multi'),
        # Swedish / Russian
        ('Opus-Sv-En', 'Helsinki-NLP/opus-mt-sv-en'),
        ('Opus-En-Ru', 'Helsinki-NLP/opus-mt-en-ru'),
        ('Opus-Ru-En', 'Helsinki-NLP/opus-mt-ru-en'),
    ]
)

# Cache of already-constructed translation pipelines, keyed by every input
# that changes how the pipeline is built (model, task, 8-bit flag, device).
loaded_models = {}

# Simple translation function
def translate_text(model_choice, task_choice, text_input, load_in_8bit, device):
    """Translate ``text_input`` with the selected model, loading it on first use.

    Pipelines are cached in ``loaded_models`` so that repeated requests with the
    same model/task/8-bit/device combination reuse the already-loaded model.

    Args:
        model_choice: Key of ``model_names`` selecting the Hub model id.
        task_choice: Key of ``translation_task_names`` selecting the pipeline
            task string (e.g. ``'translation_en_to_fr'``).
        text_input: Text to translate.
        load_in_8bit: When truthy, the model is built with
            ``model_kwargs={"load_in_8bit": True}`` and ``torch.float16``;
            otherwise with no extra kwargs and ``torch.float32``.
        device: Device passed to ``pipeline(device=...)`` (this app offers
            ``'cpu'`` and ``'cuda'``).

    Returns:
        The translated text with surrounding whitespace stripped, or ``""``
        when ``text_input`` is empty or whitespace-only.

    Raises:
        gr.Error: If no valid model or translation task is selected.
    """
    # Dropdowns deliver None when nothing is selected; fail with a readable
    # message instead of a KeyError deep in the pipeline setup.
    if model_choice not in model_names:
        raise gr.Error("Please select a model.")
    if task_choice not in translation_task_names:
        raise gr.Error("Please select a translation task.")

    # Nothing to translate: don't load a model just to process blank input.
    if not (text_input or "").strip():
        return ""

    # The device must be part of the key; otherwise a pipeline built for 'cpu'
    # would be silently reused after switching to 'cuda' (and vice versa).
    model_key = (model_choice, task_choice, load_in_8bit, device)

    translator = loaded_models.get(model_key)
    if translator is None:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        # Half precision when loading in 8-bit, full precision otherwise.
        dtype = torch.float16 if load_in_8bit else torch.float32
        # NOTE(review): 8-bit loading typically needs bitsandbytes and a CUDA
        # device; combining it with device='cpu' or with the bettertransformer
        # accelerator may fail at load time — confirm against the installed
        # transformers/optimum versions.
        translator = pipeline(
            task=translation_task_names[task_choice],
            model=model_names[model_choice],  # Use selected model
            device=device,  # Use selected device
            model_kwargs=model_kwargs,
            torch_dtype=dtype,  # Set the floating point precision
            accelerator="bettertransformer",  # Use optimum BetterTransformer
            use_fast=True,
        )
        # Store the loaded pipeline for reuse.
        loaded_models[model_key] = translator

    translation = translator(text_input)[0]['translation_text']
    return str(translation).strip()

def launch(model_choice, task_choice, text_input, load_in_8bit, device):
    """Gradio callback: pass the UI inputs straight through to ``translate_text``.

    Returns whatever ``translate_text`` returns (the translated string).
    """
    translated = translate_text(
        model_choice=model_choice,
        task_choice=task_choice,
        text_input=text_input,
        load_in_8bit=load_in_8bit,
        device=device,
    )
    return translated

# Input widgets, in the same order as the ``launch`` callback's parameters.
# Both dropdowns pre-select their first entry so the callback is never invoked
# with None for the model or the task.
model_dropdown = gr.Dropdown(
    choices=list(model_names.keys()),
    value=next(iter(model_names)),
    label='Select Model',
)
task_dropdown = gr.Dropdown(
    choices=list(translation_task_names.keys()),
    value=next(iter(translation_task_names)),
    label='Select Translation Task',
)
text_input = gr.Textbox(label="Input Text")  # Single line text input
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
# https://www.gradio.app/docs/radio
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')

iface = gr.Interface(launch, inputs=[model_dropdown, task_dropdown, text_input, load_in_8bit, device],
                     outputs=gr.Textbox(type="text", label="Translation"))

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    iface.launch()