import gradio as gr | |
from transformers import T5ForConditionalGeneration, T5Tokenizer | |
# Hugging Face Hub checkpoint id (presumably the MADLAD-400 3B MT model, per
# its name) loaded with the T5 seq2seq classes imported above.
model_name = "jbochi/madlad400-3b-mt"

# device_map="auto" asks transformers (via accelerate) to place the weights
# on whatever devices are available.
model = T5ForConditionalGeneration.from_pretrained(
    model_name,
    device_map="auto",
)

# Tokenizer from the same checkpoint so token ids match the model's vocab.
tokenizer = T5Tokenizer.from_pretrained(model_name)
def translate_to_russian(
    input_text,
    *,
    max_length=256,
    num_beams=4,
    no_repeat_ngram_size=2,
    length_penalty=1.2,
):
    """Translate ``input_text`` into Russian using the module-level model.

    The target language is chosen by prefixing the text with the ``<2ru>``
    tag; the source language is not specified.

    Args:
        input_text: Text to translate. ``None``, empty, or whitespace-only
            input returns ``""`` without running the model.
        max_length: Maximum length, in tokens, of the generated sequence.
        num_beams: Beam width for beam search.
        no_repeat_ngram_size: If > 0, no n-gram of this size may occur more
            than once in the output; 0 disables the constraint.
        length_penalty: Exponential length penalty applied to beam scores.

    Returns:
        The decoded translation with special tokens removed.
    """
    # Nothing to translate: skip the (expensive) generation pass that would
    # otherwise run on the bare "<2ru>" tag, and avoid a TypeError on None.
    if input_text is None or not input_text.strip():
        return ""

    # Strip surrounding whitespace (e.g. a trailing newline from the textbox)
    # before prepending the target-language tag.
    full_input_text = "<2ru>" + input_text.strip()
    input_ids = tokenizer(full_input_text, return_tensors="pt").input_ids.to(model.device)
    outputs = model.generate(
        input_ids=input_ids,
        max_length=max_length,
        num_beams=num_beams,
        # NOTE(review): forbidding any repeated bigram can distort longer
        # translations that legitimately repeat token pairs; consider 0
        # (disabled) or a larger value. Default kept at 2 to preserve output.
        no_repeat_ngram_size=no_repeat_ngram_size,
        length_penalty=length_penalty,
    )
    # Single input string, one returned sequence: decode the first row.
    return tokenizer.decode(outputs[0], skip_special_tokens=True)
# Define the Gradio interface: one multi-line textbox in, the translation out.
iface = gr.Interface(
    fn=translate_to_russian,
    inputs=[gr.Textbox(lines=10, placeholder="Enter text to translate")],
    outputs="textbox",
    title="Translate to Russian",
    description="Enter text in English and get the Russian translation.",
)

# Launch only when run as a script, so importing this module (e.g. to reuse
# translate_to_russian) does not start a blocking web server.
if __name__ == "__main__":
    # debug=True blocks the main thread and surfaces errors in the console.
    iface.launch(debug=True)