# Hugging Face Space "eaysu" — initial commit 6b9283a (3.84 kB)
import re

import gradio as gr
import torch
from transformers import MarianMTModel, MarianTokenizer
# Module-level cache mapping a checkpoint name -> (model, tokenizer) so each
# MarianMT model is downloaded/loaded only once per process.
models_cache = {}
def load_model(model_name):
    """Return the cached (model, tokenizer) pair for *model_name*.

    On first request the MarianMT checkpoint is downloaded, moved to GPU
    when CUDA is available, and stored in ``models_cache`` for reuse.
    """
    try:
        # EAFP: hit the cache first; the common case after warm-up.
        return models_cache[model_name]
    except KeyError:
        tokenizer = MarianTokenizer.from_pretrained(model_name)
        model = MarianMTModel.from_pretrained(model_name)
        if torch.cuda.is_available():
            model = model.to('cuda')
        models_cache[model_name] = (model, tokenizer)
        return models_cache[model_name]
def translate_text(model_name, text):
    """Translate *text* with the selected MarianMT checkpoint.

    Returns the translated string, a prompt when the model or text is
    missing, or an ``"Error: ..."`` message when loading/generation fails
    (this function is a UI boundary, so failures surface as text).
    """
    if not model_name or not text.strip():
        return "Select a model and provide text."
    try:
        model, tokenizer = load_model(model_name)
        # truncation=True: inputs longer than the model's max length would
        # otherwise raise at generation time.
        tokens = tokenizer(text, return_tensors="pt", padding=True, truncation=True)
        # Place input tensors on whatever device the model actually lives on
        # instead of re-probing CUDA availability here.
        device = next(model.parameters()).device
        tokens = {k: v.to(device) for k, v in tokens.items()}
        # Inference only — skip autograd bookkeeping.
        with torch.no_grad():
            translated = model.generate(**tokens)
        return tokenizer.decode(translated[0], skip_special_tokens=True)
    except Exception as e:
        # Deliberate broad catch: report any load/generate failure in the UI.
        return f"Error: {str(e)}"
def process_input(model_name, text):
    """Split *text* into sentences and translate each completed one.

    A sentence counts as complete when it ends with '.', '!' or '?';
    an incomplete trailing fragment is echoed back with a waiting notice.
    """
    # Split on whitespace that FOLLOWS terminal punctuation so every piece
    # keeps its '.', '!' or '?'.  The previous str.split('. ') stripped the
    # period from all sentences but the last, so finished sentences were
    # wrongly reported as incomplete.
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())
    translations = []
    for sentence in sentences:
        if sentence.endswith(('.', '!', '?')):
            translations.append(translate_text(model_name, sentence))
        else:
            # Incomplete fragment — don't translate yet.
            translations.append(f"Waiting for completion: {sentence}")
    return '\n'.join(translations)
# Available MarianMT checkpoints, shown in the model dropdown.
MODEL_CHOICES = [
    "Helsinki-NLP/opus-mt-tc-big-en-tr",  # English to Turkish
    "Helsinki-NLP/opus-mt-tc-big-tr-en",  # Turkish to English
    "Helsinki-NLP/opus-mt-tc-big-en-fr",  # English to French
    "Helsinki-NLP/opus-mt-tc-big-fr-en",  # French to English
    "Helsinki-NLP/opus-mt-en-de",  # English to German
    "Helsinki-NLP/opus-mt-de-en",  # German to English
    "Helsinki-NLP/opus-mt-tc-big-en-es",  # English to Spanish
    "Helsinki-NLP/opus-mt-es-en",  # Spanish to English
    "Helsinki-NLP/opus-mt-tc-big-en-ar",  # English to Arabic
    "Helsinki-NLP/opus-mt-tc-big-ar-en",  # Arabic to English
    "Helsinki-NLP/opus-mt-en-ur",  # English to Urdu
    "Helsinki-NLP/opus-mt-ur-en",  # Urdu to English
    "Helsinki-NLP/opus-mt-en-hi",  # English to Hindi
    "Helsinki-NLP/opus-mt-hi-en",  # Hindi to English
    "Helsinki-NLP/opus-mt-en-zh",  # English to Chinese
    "Helsinki-NLP/opus-mt-zh-en",  # Chinese to English
]

# Build the Gradio UI: model selector + input box, with the translation
# re-rendered on every input change.
with gr.Blocks() as app:
    gr.Markdown("## 🌍 Real-Time Sentence Translation")
    gr.Markdown("### Enter text in the textbox, and it will be translated after each sentence ends!")
    model_dropdown = gr.Dropdown(
        label="Select Translation Model",
        choices=MODEL_CHOICES,
        value="Helsinki-NLP/opus-mt-tc-big-en-tr",
        interactive=True,
    )
    input_textbox = gr.Textbox(
        label="Enter Text:",
        placeholder="Type here...",
        lines=5,
        interactive=True,
    )
    output_textbox = gr.Textbox(
        label="Translated Text:",
        lines=5,
        interactive=False,
    )
    # Re-translate on every keystroke; completed sentences only.
    input_textbox.change(
        fn=process_input,
        inputs=[model_dropdown, input_textbox],
        outputs=[output_textbox],
    )

# Launch Gradio App
app.launch()