Spaces:
Sleeping
Sleeping
File size: 7,550 Bytes
ce4167f |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 |
import re

import gradio as gr
import torch
from huggingface_hub import HfFolder
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
from transformers import MarianMTModel, MarianTokenizer
# Registries populated by load_models(): direction key (e.g. "en-hi") mapped to
# the loaded tokenizer / model instance for that direction.
tokenizers = dict()
models = dict()
# Credential read via HfFolder and forwarded to every from_pretrained call
# (presumably None when no token is configured locally -- verify).
token = HfFolder.get_token()
# One entry per supported translation direction, keyed "<src>-<tgt>" using the
# codes from language_codes. "model_path" is the repo ID handed to
# from_pretrained(); "name" is the human-readable label printed while loading.
# load_models() walks this dict in insertion order.
MODEL_CONFIGS = {
    "en-hi": dict(
        model_path="rooftopcoder/opus-mt-en-hi-samanantar-finetuned",
        name="English to Hindi",
    ),
    "hi-en": dict(
        model_path="rooftopcoder/opus-mt-hi-en-samanantar-finetuned",
        name="Hindi to English",
    ),
    "en-mr": dict(
        model_path="rooftopcoder/opus-mt-en-mr-samanantar-finetuned",
        name="English to Marathi",
    ),
    "mr-en": dict(
        model_path="rooftopcoder/opus-mt-mr-en-samanantar-finetuned",
        name="Marathi to English",
    ),
}

# Display name shown in the UI dropdowns -> short language code.
language_codes = dict(English="en", Hindi="hi", Marathi="mr")

# Inverse lookup: language code -> display name.
language_names = {code: label for label, code in language_codes.items()}
def load_models():
    """Load every tokenizer/model pair listed in MODEL_CONFIGS.

    Fills the module-level ``tokenizers`` and ``models`` dicts, keyed by
    direction (e.g. ``"en-hi"``). Each model is moved to CUDA when
    ``torch.cuda.is_available()`` is true, otherwise to CPU.

    Returns:
        True when every pair loaded; False if any step raised. The error is
        printed, and pairs loaded before the failure remain in the dicts.
    """
    try:
        print("Loading models from local storage...")
        target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        print(f"Using device: {target}")
        for direction, cfg in MODEL_CONFIGS.items():
            # model_path values are repo IDs resolved by from_pretrained.
            repo = cfg["model_path"]
            print(f"Loading {cfg['name']} model...")
            tokenizers[direction] = MarianTokenizer.from_pretrained(repo, token=token)
            model = MarianMTModel.from_pretrained(repo, token=token)
            models[direction] = model.to(target)
        print("All models loaded successfully!")
    except Exception as err:
        print(f"Error loading models: {err}")
        return False
    return True
# Script conversion helper; by default romanized (ITRANS) text -> Devanagari,
# which is used below for both Hindi and Marathi targets.
def transliterate_text(text, from_scheme=sanscript.ITRANS, to_scheme=sanscript.DEVANAGARI):
    """Convert *text* from one script scheme to another.

    Args:
        text: Input string.
        from_scheme: indic_transliteration source scheme (default ITRANS).
        to_scheme: indic_transliteration target scheme (default Devanagari).

    Returns:
        The transliterated string. If the library raises, the error is
        printed and *text* is returned unchanged (best-effort).
    """
    result = text
    try:
        result = transliterate(text, from_scheme, to_scheme)
    except Exception as exc:
        print(f"Transliteration error: {exc}")
    return result
# Function to perform translation with MarianMT
def translate(input_text, source_lang, target_lang):
    """Translate *input_text* with the MarianMT model for one direction.

    Args:
        input_text: Text to translate. ``None`` is treated like blank text.
        source_lang: Source language code, e.g. ``"en"``.
        target_lang: Target language code, e.g. ``"hi"``.

    Returns:
        The decoded translation. On failure it returns a human-readable string
        starting with ``"Error"``: for an unsupported direction, blank input,
        or an exception during tokenization/generation. Errors are returned
        rather than raised so the UI can show them directly.
    """
    direction = f"{source_lang}-{target_lang}"
    if direction not in models or direction not in tokenizers:
        return "Error: Unsupported language pair"
    # Check for None before calling .strip(). This function is exposed through
    # api_name="translate", so callers other than the textbox could send null.
    if not input_text or not input_text.strip():
        return "Error: Please enter some text to translate."
    tokenizer = tokenizers[direction]
    model = models[direction]
    try:
        # Inputs must be on the same device the model was moved to at load time.
        device = next(model.parameters()).device
        batch = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        batch = {name: tensor.to(device) for name, tensor in batch.items()}
        generated = model.generate(**batch)
        decoded = tokenizer.batch_decode(generated.cpu(), skip_special_tokens=True)
        return decoded[0]
    except Exception as e:
        print(f"Translation error: {e}")
        return f"Error during translation: {str(e)}"
# Romanized words that hint the "English" input is actually typed Hindi or
# Marathi. When one appears, the ITRANS -> Devanagari transliteration is shown
# next to the translation. Frozen once at import for O(1) membership checks.
# NOTE(review): "main" is also a common English word, so English sentences
# containing it will trigger the transliteration path for Hindi.
_COMMON_INDIC_WORDS = {
    "hi": frozenset({"namaste", "dhanyavad", "kaise", "hai", "aap", "tum", "main"}),
    "mr": frozenset({"namaskar", "dhanyawad", "kase", "ahe", "tumhi", "mi"}),
}

# Runs of ASCII letters. Tokenizing this way stops attached punctuation
# ("namaste,", "aap?") from hiding a known word, which str.split() did not.
_ROMAN_WORD_RE = re.compile(r"[a-z]+")


# Helper function for handling the UI translation process
def perform_translation(input_text, source_lang, target_lang):
    """Gradio callback for the Translate button and the cached examples.

    Maps the dropdown display names to language codes and delegates to
    translate(). For English -> Hindi/Marathi, if the input contains a known
    romanized Hindi/Marathi word, the ITRANS -> Devanagari transliteration is
    returned alongside the translation.

    Args:
        input_text: Contents of the input textbox.
        source_lang: Source display name (a key of ``language_codes``).
        target_lang: Target display name (a key of ``language_codes``).

    Returns:
        Text for the output box. Failures are reported as strings starting
        with ``"Error"``, matching translate().
    """
    source_code = language_codes.get(source_lang)
    target_code = language_codes.get(target_lang)
    # A cleared or unexpected dropdown value used to raise KeyError here.
    if source_code is None or target_code is None:
        return "Error: Unsupported language pair"

    if source_code == "en" and target_code in _COMMON_INDIC_WORDS and input_text:
        hints = _COMMON_INDIC_WORDS[target_code]
        words = _ROMAN_WORD_RE.findall(input_text.lower())
        if any(word in hints for word in words):
            transliterated = transliterate_text(input_text)
            # Only show the extra section if transliteration changed something.
            if transliterated != input_text:
                translation = translate(input_text, source_code, target_code)
                return f"Transliterated: {transliterated}\n\nTranslated: {translation}"
    return translate(input_text, source_code, target_code)
# Handler for the "Transliterate Only" button.
def _direct_transliterate(text):
    """Transliterate *text* with the default ITRANS -> Devanagari schemes.

    The source-language dropdown is not consulted. Blank input returns a
    prompt message instead.
    """
    if text.strip():
        return transliterate_text(text)
    return "Please enter text to transliterate"


# Static footer rendered under the UI.
_MODEL_INFO_MD = """
## Model Information
This demo uses fine-tuned MarianMT models for translation between:
- English ↔️ Hindi
- English ↔️ Marathi
### Features:
- Bidirectional translation support
- Transliteration support for romanized Indic text
- Optimized models for each language pair
"""


# Create Gradio interface
def create_interface():
    """Build (but do not launch) the Gradio Blocks UI.

    The layout is a header, then two columns (source language + input,
    target language + output), then the action buttons, the examples and a
    footer. The examples use cache_examples=True, which may run
    perform_translation while the UI is being built, so models should
    already be loaded.

    Returns:
        The constructed ``gr.Blocks`` instance.
    """
    with gr.Blocks(title="Neural Machine Translation - Indian Languages") as demo:
        gr.Markdown("# Neural Machine Translation for Indian Languages")
        gr.Markdown("Translate between English, Hindi, and Marathi using MarianMT models")

        with gr.Row():
            with gr.Column():
                src_dropdown = gr.Dropdown(
                    choices=list(language_codes),
                    label="Source Language",
                    value="English",
                )
                text_in = gr.Textbox(
                    lines=5,
                    placeholder="Enter text to translate...",
                    label="Input Text",
                )
            with gr.Column():
                tgt_dropdown = gr.Dropdown(
                    choices=list(language_codes),
                    label="Target Language",
                    value="Hindi",
                )
                text_out = gr.Textbox(
                    lines=5,
                    label="Translated Text",
                    placeholder="Translation will appear here...",
                )

        # NOTE(review): the original nesting was lost in this copy. The
        # buttons are assumed to sit below the two-column row; confirm.
        translate_button = gr.Button("Translate", variant="primary")
        transliterate_button = gr.Button("Transliterate Only", variant="secondary")

        # Wire both buttons to the shared output box; api_name exposes each
        # handler as a named API endpoint.
        translate_button.click(
            fn=perform_translation,
            inputs=[text_in, src_dropdown, tgt_dropdown],
            outputs=[text_out],
            api_name="translate",
        )
        transliterate_button.click(
            fn=_direct_transliterate,
            inputs=[text_in],
            outputs=[text_out],
            api_name="transliterate",
        )

        # One example per supported direction.
        gr.Examples(
            examples=[
                ["Hello, how are you?", "English", "Hindi"],
                ["नमस्ते, आप कैसे हैं?", "Hindi", "English"],
                ["Hello, how are you?", "English", "Marathi"],
                ["नमस्कार, तुम्ही कसे आहात?", "Marathi", "English"],
            ],
            inputs=[text_in, src_dropdown, tgt_dropdown],
            fn=perform_translation,
            outputs=text_out,
            cache_examples=True,
        )

        gr.Markdown(_MODEL_INFO_MD)
    return demo
# Launch the interface
if __name__ == "__main__":
    # Load the models before building the UI: create_interface() uses
    # cache_examples=True, which may run perform_translation on the examples.
    if not load_models():
        print("Failed to load models. Please check the model paths and try again.")
        # Exit non-zero so a process supervisor / container runtime sees the
        # failure. Previously the script fell through and exited with status 0.
        raise SystemExit(1)
    demo = create_interface()
    demo.launch(share=False)
|