import gradio as gr
import torch
from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
import re

# Use the GPU if one is available, otherwise fall back to the CPU
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load model and tokenizer
model_name = "facebook/mbart-large-50-many-to-many-mmt"
model = MBartForConditionalGeneration.from_pretrained(model_name).to(device)
tokenizer = MBart50TokenizerFast.from_pretrained(model_name)
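# Note: mbart-large-50 is a ~600M-parameter checkpoint (roughly 2.4 GB), so the
# first run downloads and caches it before the app becomes responsive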

# mBART-50 language codes: English ("en_XX") to Tamil ("ta_IN")
src_lang, tgt_lang = "en_XX", "ta_IN"

# Technical terms to preserve in the output; matching is case-insensitive
# (expand this list as needed)
technical_terms = {
    "machine translation", "natural language processing", "nlp", "transformer architecture",
    "machine learning", "deep learning", "artificial intelligence", "ai", "neural network",
    "algorithms", "data science", "big data", "cloud computing", "internet of things", "iot",
    "blockchain", "cybersecurity", "virtual reality", "vr", "augmented reality", "ar",
    "robotics", "automation", "quantum computing", "5g", "edge computing", "devops",
    "microservices", "api", "serverless", "container", "docker", "kubernetes", "ml",
    "computer vision", "natural language understanding", "nlu", "speech recognition",
    "sentiment analysis", "chatbot", "reinforcement learning", "supervised learning",
    "unsupervised learning", "convolutional neural network", "cnn", "recurrent neural network", "rnn",
    "long short-term memory", "lstm", "generative adversarial network", "gan",
    "transfer learning", "federated learning", "explainable ai", "xai"
}

# Build a single alternation pattern, longest terms first, so a shorter term
# (e.g. "neural network") is never re-wrapped inside a longer one that already
# matched (e.g. "convolutional neural network"), which the original
# term-by-term substitution loop would nest
term_pattern = re.compile(
    r'\b(' + '|'.join(re.escape(t) for t in sorted(technical_terms, key=len, reverse=True)) + r')\b',
    re.IGNORECASE
)

def preprocess_text(text):
    # Wrap each matched technical term in <keep> markers so it can be located after translation
    return term_pattern.sub(lambda m: f"<keep>{m.group()}</keep>", text)

def postprocess_text(text):
    # Strip the <keep> markers and render the preserved terms in Markdown bold
    return re.sub(r'<keep>(.*?)</keep>', r'**\1**', text)
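
# Caveat: <keep> is ordinary text to mBART, not a registered special token, so
# the model may translate, drop, or reorder the markers during generation; the
# wrap/unwrap round trip is a best-effort heuristic, not a guarantee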

def translate(text, src_lang=src_lang, tgt_lang=tgt_lang):
    # Mark the technical terms before translation
    preprocessed_text = preprocess_text(text)
    # Tell the tokenizer the source language (required by mBART-50), then tokenize
    tokenizer.src_lang = src_lang
    inputs = tokenizer(preprocessed_text, return_tensors="pt", padding=True, truncation=True, max_length=512)
    inputs = {k: v.to(device) for k, v in inputs.items()}
    # Generate the translation, forcing the decoder to start with the target language token
    translated = model.generate(
        **inputs,
        forced_bos_token_id=tokenizer.lang_code_to_id[tgt_lang],
        max_length=512,
        num_beams=5,
        length_penalty=1.0,
        early_stopping=True
    )
    # Decode the generated token ids back to text
    translated_text = tokenizer.batch_decode(translated, skip_special_tokens=True)[0]
    # Restore the marked terms in the translated text
    return postprocess_text(translated_text)
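
# Example call (input is illustrative only; the Tamil output depends on the model):
#   translate("Deep learning has transformed machine translation and nlp.")
# Terms whose <keep> markers survive generation come back wrapped in ** (bold)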

# Gradio interface
def gradio_translate(text):
    return translate(text)

iface = gr.Interface(
    fn=gradio_translate,
    inputs=gr.Textbox(lines=5, label="English Text"),
    outputs=gr.Textbox(lines=5, label="Tamil Translation"),
    title="English to Tamil Translation with Technical Terms Preserved",
    description="This app translates English text to Tamil while preserving technical terms."
)

iface.launch()
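
# launch() starts the web server and blocks; on Hugging Face Spaces this file
# is run as the entry point, so no extra server configuration is needed here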