Spaces:
Sleeping
Sleeping
import gradio as gr | |
from huggingface_hub import HfFolder | |
from transformers import MarianMTModel, MarianTokenizer | |
from indic_transliteration import sanscript | |
from indic_transliteration.sanscript import transliterate | |
import torch # Add this import at the top with other imports | |
# Per-direction caches filled in by load_models(); keys are the MODEL_CONFIGS
# keys such as "en-hi".
models = dict()
tokenizers = dict()
# Auth token from HfFolder.get_token() (presumably None when no token is
# configured — verify); passed as `token=` to every from_pretrained call.
token = HfFolder.get_token()
# One entry per translation direction, keyed "<source code>-<target code>".
# "model_path" is the identifier handed to from_pretrained; "name" is the
# human-readable label printed while loading.
MODEL_CONFIGS = {
    "en-hi": dict(
        model_path="rooftopcoder/opus-mt-en-hi-samanantar-finetuned",
        name="English to Hindi",
    ),
    "hi-en": dict(
        model_path="rooftopcoder/opus-mt-hi-en-samanantar-finetuned",
        name="Hindi to English",
    ),
    "en-mr": dict(
        model_path="rooftopcoder/opus-mt-en-mr-samanantar-finetuned",
        name="English to Marathi",
    ),
    "mr-en": dict(
        model_path="rooftopcoder/opus-mt-mr-en-samanantar-finetuned",
        name="Marathi to English",
    ),
}
# Display name -> language code; the display names are the dropdown choices.
language_codes = dict(English="en", Hindi="hi", Marathi="mr")
# Inverse lookup: language code -> display name.
language_names = {code: label for label, code in language_codes.items()}
def load_models():
    """Fill the module-level ``tokenizers`` and ``models`` dicts.

    For each direction in MODEL_CONFIGS, a MarianTokenizer and a MarianMTModel
    are created with ``from_pretrained`` (passing the module-level ``token``).
    Each model is moved to CUDA when available, otherwise to CPU.

    Returns:
        True if every direction loaded, False if any step raised. The error is
        printed, not propagated. Directions loaded before the failure stay
        in the dicts.
    """
    try:
        # NOTE(review): model_path values look like Hub repo ids, so despite
        # the message below they may be fetched from the Hub / local cache.
        print("Loading models from local storage...")
        target_device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {target_device}")
        for pair_key, cfg in MODEL_CONFIGS.items():
            repo = cfg["model_path"]
            print(f"Loading {cfg['name']} model...")
            tokenizers[pair_key] = MarianTokenizer.from_pretrained(repo, token=token)
            loaded = MarianMTModel.from_pretrained(repo, token=token)
            models[pair_key] = loaded.to(target_device)
        print("All models loaded successfully!")
    except Exception as err:
        print(f"Error loading models: {err}")
        return False
    return True
# Transliterate text between indic_transliteration schemes
# (by default ITRANS romanization -> Devanagari script).
def transliterate_text(text, from_scheme=sanscript.ITRANS, to_scheme=sanscript.DEVANAGARI):
    """Convert *text* from ``from_scheme`` to ``to_scheme``.

    The defaults map ITRANS-romanized input to Devanagari, the script used
    for both Hindi and Marathi. If transliteration raises, the error is
    printed and *text* is returned unchanged instead of propagating.
    """
    try:
        converted = transliterate(text, from_scheme, to_scheme)
    except Exception as exc:
        print(f"Transliteration error: {exc}")
        converted = text
    return converted
# Function to perform translation with MarianMT
def translate(input_text, source_lang, target_lang):
    """Translate *input_text* with the loaded MarianMT model for one direction.

    Args:
        input_text: Text to translate. ``None``, non-string, or
            whitespace-only input is rejected with an error message.
        source_lang: Source language code, e.g. ``"en"``.
        target_lang: Target language code, e.g. ``"hi"``.

    Returns:
        The decoded translation of the input. On failure, a string starting
        with ``"Error"``. Errors are returned in-band rather than raised
        because the result is displayed directly in the UI textbox.
        Input longer than the tokenizer's maximum length is truncated
        (``truncation=True``).
    """
    direction = f"{source_lang}-{target_lang}"
    if direction not in models or direction not in tokenizers:
        return "Error: Unsupported language pair"
    # Guard against None / non-string input so this check cannot raise
    # AttributeError outside the try block below.
    if not isinstance(input_text, str) or not input_text.strip():
        return "Error: Please enter some text to translate."
    model = models[direction]
    tokenizer = tokenizers[direction]
    try:
        # Send the inputs to whichever device holds the model's weights.
        device = next(model.parameters()).device
        batch = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        generated = model.generate(**batch)
        decoded = tokenizer.batch_decode(generated.cpu(), skip_special_tokens=True)
        # A single input string was tokenized, so there is one output sequence.
        return decoded[0]
    except Exception as e:
        print(f"Translation error: {e}")
        return f"Error during translation: {str(e)}"
# Romanized (Latin-script) words suggesting that "English" input is really
# Hindi or Marathi typed phonetically, keyed by target language code.
_ROMANIZED_HINT_WORDS = {
    "hi": frozenset({"namaste", "dhanyavad", "kaise", "hai", "aap", "tum", "main"}),
    "mr": frozenset({"namaskar", "dhanyawad", "kase", "ahe", "tumhi", "mi"}),
}


# Helper function for handling the UI translation process
def perform_translation(input_text, source_lang, target_lang):
    """Gradio callback: translate *input_text* between two display languages.

    Args:
        input_text: Contents of the input textbox (``None`` is treated as "").
        source_lang: Source display name, e.g. ``"English"``.
        target_lang: Target display name, e.g. ``"Hindi"``.

    Returns:
        The string produced by ``translate``. For English -> Hindi/Marathi,
        if the input contains one of the romanized hint words (case- and
        punctuation-insensitive) and transliteration changes the text, the
        result is prefixed with an ITRANS -> Devanagari transliteration.
        Unknown display names yield ``"Error: Unsupported language pair"``
        instead of raising KeyError.
    """
    import string  # only needed to strip punctuation around hint words

    source_code = language_codes.get(source_lang)
    target_code = language_codes.get(target_lang)
    if source_code is None or target_code is None:
        return "Error: Unsupported language pair"
    if input_text is None:
        input_text = ""

    hint_words = _ROMANIZED_HINT_WORDS.get(target_code)
    if source_code == "en" and hint_words:
        # Strip surrounding punctuation so "Namaste!" or "hai?" still match.
        words = (word.strip(string.punctuation) for word in input_text.lower().split())
        if any(word in hint_words for word in words):
            transliterated = transliterate_text(input_text)
            if transliterated != input_text:
                translation = translate(input_text, source_code, target_code)
                return f"Transliterated: {transliterated}\n\nTranslated: {translation}"
    return translate(input_text, source_code, target_code)
# Create Gradio interface
def create_interface():
    """Build and return the Gradio Blocks UI (it is not launched here).

    Layout:
    - a source column (language dropdown + input textbox);
    - a target column (language dropdown + output textbox);
    - "Translate" and "Transliterate Only" buttons, both writing into the
      same output textbox;
    - a set of examples and a model-information section.

    Returns:
        The ``gr.Blocks`` instance; the caller invokes ``demo.launch()``.
    """
    with gr.Blocks(title="Neural Machine Translation - Indian Languages") as demo:
        gr.Markdown("# Neural Machine Translation for Indian Languages")
        gr.Markdown("Translate between English, Hindi, and Marathi using MarianMT models")
        with gr.Row():
            # Left column: source language and the text to translate.
            with gr.Column():
                source_lang = gr.Dropdown(
                    # Choices are display names; perform_translation maps them
                    # to language codes through language_codes.
                    choices=list(language_codes.keys()),
                    label="Source Language",
                    value="English"
                )
                input_text = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to translate...",
                    label="Input Text"
                )
            # Right column: target language and the result.
            with gr.Column():
                target_lang = gr.Dropdown(
                    choices=list(language_codes.keys()),
                    label="Target Language",
                    value="Hindi"
                )
                output_text = gr.Textbox(
                    lines=5,
                    label="Translated Text",
                    placeholder="Translation will appear here..."
                )
        translate_btn = gr.Button("Translate", variant="primary")
        transliterate_btn = gr.Button("Transliterate Only", variant="secondary")
        # Event handlers
        translate_btn.click(
            fn=perform_translation,
            inputs=[input_text, source_lang, target_lang],
            outputs=[output_text],
            api_name="translate"
        )
        # Direct transliteration handler (new)
        # Always uses transliterate_text's defaults (ITRANS -> Devanagari); the
        # language dropdowns are not consulted.
        # NOTE(review): text.strip() raises AttributeError if text is None.
        def direct_transliterate(text):
            if not text.strip():
                return "Please enter text to transliterate"
            return transliterate_text(text)
        transliterate_btn.click(
            fn=direct_transliterate,
            inputs=[input_text],
            outputs=[output_text],
            api_name="transliterate"
        )
        # Examples for all language pairs
        # NOTE(review): with cache_examples=True, Gradio precomputes the outputs
        # by calling perform_translation, which needs the models to be loaded
        # (the __main__ block calls load_models() before create_interface()).
        # Confirm when caching happens for the installed Gradio version.
        gr.Examples(
            examples=[
                ["Hello, how are you?", "English", "Hindi"],
                ["नमस्ते, आप कैसे हैं?", "Hindi", "English"],
                ["Hello, how are you?", "English", "Marathi"],
                ["नमस्कार, तुम्ही कसे आहात?", "Marathi", "English"],
            ],
            inputs=[input_text, source_lang, target_lang],
            fn=perform_translation,
            outputs=output_text,
            cache_examples=True
        )
        gr.Markdown("""
## Model Information
This demo uses fine-tuned MarianMT models for translation between:
- English ↔️ Hindi
- English ↔️ Marathi
### Features:
- Bidirectional translation support
- Transliteration support for romanized Indic text
- Optimized models for each language pair
""")
    return demo
# Launch the interface
if __name__ == "__main__":
    # Models must be loaded before the UI is built: the translate handler and
    # the cached examples both read the module-level model/tokenizer dicts.
    if not load_models():
        # SystemExit with a message prints it to stderr and exits with status 1,
        # so a supervisor/container sees the failure instead of a clean exit.
        raise SystemExit("Failed to load models. Please check the model paths and try again.")
    demo = create_interface()
    demo.launch(share=False)