import torch
import gradio as gr
from transformers import pipeline
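# Translation tasks exposed in the UI, mapped to transformers pipeline task names.
# Commented-out pairs are currently disabled.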
translation_task_names = {
    'English to French': 'translation_en_to_fr',
    # 'French to English': 'translation_fr_to_en',
    # 'English to Spanish': 'translation_en_to_es',
    # 'Spanish to English': 'translation_es_to_en',
    'English to German': 'translation_en_to_de',
    # 'German to English': 'translation_de_to_en',
    # 'English to Italian': 'translation_en_to_it',
    # 'Italian to English': 'translation_it_to_en',
    # 'English to Dutch': 'translation_en_to_nl',
    # 'Dutch to English': 'translation_nl_to_en',
    # 'English to Portuguese': 'translation_en_to_pt',
    # 'Portuguese to English': 'translation_pt_to_en',
    # 'English to Russian': 'translation_en_to_ru',
    # 'Russian to English': 'translation_ru_to_en',
    'English to Chinese': 'translation_en_to_zh',
    # 'Chinese to English': 'translation_zh_to_en',
    # 'English to Japanese': 'translation_en_to_ja',
    # 'Japanese to English': 'translation_ja_to_en',
    'English to Romanian': 'translation_en_to_ro',
}
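# Model checkpoints selectable in the UI. The T5 checkpoints support the built-in
# T5 translation tasks above; 'liam168/trans-opus-mt-en-zh' is an Opus-MT-style
# checkpoint intended for English-to-Chinese only.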
model_names = {
    'T5-Base': 't5-base',
    'T5-Small': 't5-small',
    'T5-Large': 't5-large',
    'Opus-En-ZH': 'liam168/trans-opus-mt-en-zh',
}
# Cache of already-loaded translation pipelines, keyed by (model, task, 8-bit flag)
loaded_models = {}
# Simple translation function
def translate_text(model_choice, task_choice, text_input, load_in_8bit, device):
    # Key the cache on the unique combination of model, task, and 8-bit loading
    model_key = (model_choice, task_choice, load_in_8bit)
    # Check if this pipeline is already loaded
    if model_key in loaded_models:
        translator = loaded_models[model_key]
    else:
        model_kwargs = {"load_in_8bit": load_in_8bit} if load_in_8bit else {}
        dtype = torch.float16 if load_in_8bit else torch.float32  # Set dtype based on load_in_8bit
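        # Note: load_in_8bit depends on the bitsandbytes package and generally
        # requires a CUDA device; combining it with device='cpu' will fail.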
        translator = pipeline(
            task=translation_task_names[task_choice],
            model=model_names[model_choice],  # Use selected model
            device=device,                    # Use selected device
            model_kwargs=model_kwargs,
            torch_dtype=dtype,                # Set the floating-point precision
            use_fast=True,
        )
        # Store the loaded pipeline for reuse
        loaded_models[model_key] = translator
    translation = translator(text_input)[0]['translation_text']
    return str(translation).strip()
def launch(model_choice, task_choice, text_input, load_in_8bit, device):
    return translate_text(model_choice, task_choice, text_input, load_in_8bit, device)
model_dropdown = gr.Dropdown(choices=list(model_names.keys()), label='Select Model')
task_dropdown = gr.Dropdown(choices=list(translation_task_names.keys()), label='Select Translation Task')
text_input = gr.Textbox(label="Input Text") # Single line text input
load_in_8bit = gr.Checkbox(label="Load model in 8bit")
# https://www.gradio.app/docs/radio
device = gr.Radio(['cpu', 'cuda'], label='Select device', value='cpu')
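# 'cuda' only works when the host actually exposes a GPU; otherwise keep 'cpu'.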
iface = gr.Interface(launch,
                     inputs=[model_dropdown, task_dropdown, text_input, load_in_8bit, device],
                     outputs=gr.Textbox(type="text", label="Translation"))
iface.launch()