# File size: 7,550 Bytes
# ce4167f
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import gradio as gr
from huggingface_hub import HfFolder
from transformers import MarianMTModel, MarianTokenizer
from indic_transliteration import sanscript
from indic_transliteration.sanscript import transliterate
import torch  # Add this import at the top with other imports

# Hugging Face access token; forwarded to every from_pretrained() call in load_models()
token = HfFolder.get_token()

# Registries filled in by load_models(), keyed by direction string (e.g. "en-hi")
tokenizers = {}
models = {}

# One row per supported direction: (direction key, model path, display name)
_MODEL_TABLE = (
    ("en-hi", "rooftopcoder/opus-mt-en-hi-samanantar-finetuned", "English to Hindi"),
    ("hi-en", "rooftopcoder/opus-mt-hi-en-samanantar-finetuned", "Hindi to English"),
    ("en-mr", "rooftopcoder/opus-mt-en-mr-samanantar-finetuned", "English to Marathi"),
    ("mr-en", "rooftopcoder/opus-mt-mr-en-samanantar-finetuned", "Marathi to English"),
)

# Model configurations, keyed by "<source code>-<target code>"
MODEL_CONFIGS = {
    direction: {"model_path": model_path, "name": display_name}
    for direction, model_path, display_name in _MODEL_TABLE
}

# Display name -> language code; insertion order is the order shown in the dropdowns
language_codes = dict(
    zip(
        ("English", "Hindi", "Marathi"),
        ("en", "hi", "mr"),
    )
)

# Inverse lookup (language code -> display name)
language_names = dict(zip(language_codes.values(), language_codes.keys()))

def load_models():
    """
    Loads the tokenizer and model for every entry in MODEL_CONFIGS into the
    module-level ``tokenizers`` / ``models`` registries.

    Models are moved to CUDA when torch reports it as available, otherwise
    they stay on the CPU. Returns True when everything loaded, or False
    (after printing the error) if any step raised.
    """
    try:
        print("Loading models from local storage...")
        use_cuda = torch.cuda.is_available()
        device = torch.device("cuda" if use_cuda else "cpu")
        print(f"Using device: {device}")

        for pair in MODEL_CONFIGS:
            settings = MODEL_CONFIGS[pair]
            print(f"Loading {settings['name']} model...")
            checkpoint = settings["model_path"]
            # Tokenizer is registered first, then the model is loaded and moved to the device
            tokenizers[pair] = MarianTokenizer.from_pretrained(checkpoint, token=token)
            loaded = MarianMTModel.from_pretrained(checkpoint, token=token)
            models[pair] = loaded.to(device)

        print("All models loaded successfully!")
    except Exception as exc:
        print(f"Error loading models: {exc}")
        return False
    return True

# Function to convert text between scripts (romanized ITRANS -> Devanagari by default)
def transliterate_text(text, from_scheme=sanscript.ITRANS, to_scheme=sanscript.DEVANAGARI):
    """
    Converts ``text`` from ``from_scheme`` to ``to_scheme`` using
    indic_transliteration's ``transliterate``.

    If the conversion raises, the error is printed and the input text is
    returned unchanged.
    """
    try:
        converted = transliterate(text, from_scheme, to_scheme)
    except Exception as exc:
        print(f"Transliteration error: {exc}")
        converted = text
    return converted

# Function to run a single MarianMT translation
def translate(input_text, source_lang, target_lang):
    """
    Translates ``input_text`` with the model registered for the
    "<source_lang>-<target_lang>" direction.

    Returns the decoded translation, or an "Error..." string when the
    direction has no loaded model/tokenizer, the text is blank, or the
    model call raises (the exception is also printed).
    """
    pair = f"{source_lang}-{target_lang}"
    if not (pair in models and pair in tokenizers):
        return "Error: Unsupported language pair"

    if not input_text.strip():
        return "Error: Please enter some text to translate."

    try:
        model = models[pair]
        tokenizer = tokenizers[pair]

        # Inputs have to live on the same device as the model weights
        target_device = next(model.parameters()).device
        encoded = tokenizer(input_text, return_tensors="pt", padding=True, truncation=True)
        encoded = {name: tensor.to(target_device) for name, tensor in encoded.items()}

        generated = model.generate(**encoded).cpu()
        decoded = tokenizer.batch_decode(generated, skip_special_tokens=True)
        return decoded[0]
    except Exception as exc:
        print(f"Translation error: {exc}")
        return f"Error during translation: {exc}"

# Romanized Hindi/Marathi words whose presence in "English" input suggests the
# user typed Indic text in Latin script. Built once instead of on every call.
_ROMANIZED_INDIC_HINTS = {
    "hi": frozenset({"namaste", "dhanyavad", "kaise", "hai", "aap", "tum", "main"}),
    "mr": frozenset({"namaskar", "dhanyawad", "kase", "ahe", "tumhi", "mi"}),
}


def _contains_romanized_indic(text, target_code):
    """Return True if any word of ``text`` is a known romanized word for ``target_code``."""
    import string  # local import: the file's import block is left untouched

    hints = _ROMANIZED_INDIC_HINTS.get(target_code, frozenset())
    # Strip surrounding punctuation so "namaste!" or "dhanyavad." still match.
    return any(
        word.strip(string.punctuation) in hints
        for word in text.lower().split()
    )


# Helper function for handling the UI translation process
def perform_translation(input_text, source_lang, target_lang):
    """Wrapper function for the Gradio interface.

    Maps the dropdown display names to language codes and delegates to
    ``translate``. For English -> Hindi/Marathi requests whose text contains
    a known romanized Indic word, the result additionally carries the
    Devanagari transliteration of the input.

    Args:
        input_text: Text from the input box; ``None`` is treated as empty.
        source_lang: Display name of the source language (key of ``language_codes``).
        target_lang: Display name of the target language (key of ``language_codes``).

    Returns:
        The translated text, a combined "Transliterated: ... / Translated: ..."
        string, or an "Error: ..." message. Unknown or missing language names
        yield the same unsupported-pair message ``translate`` uses instead of
        raising ``KeyError``.
    """
    source_code = language_codes.get(source_lang)
    target_code = language_codes.get(target_lang)
    if source_code is None or target_code is None:
        # Presumably reachable when a dropdown is cleared or the API is called
        # with an unmapped name; report it rather than crash.
        return "Error: Unsupported language pair"

    # Normalise a missing value so the string operations below cannot fail;
    # translate() then reports the blank input with its usual message.
    text = "" if input_text is None else input_text

    # Handle transliteration for Hindi and Marathi
    if (
        source_code == "en"
        and target_code in ("hi", "mr")
        and _contains_romanized_indic(text, target_code)
    ):
        transliterated = transliterate_text(text)
        if transliterated != text:
            translation = translate(text, source_code, target_code)
            return f"Transliterated: {transliterated}\n\nTranslated: {translation}"

    return translate(text, source_code, target_code)

# Build the Gradio UI (the caller is responsible for launching it)
def create_interface():
    """
    Assembles the translation demo as a ``gr.Blocks`` app and returns it.

    Layout: a heading, two columns (source dropdown + input box, target
    dropdown + output box), a "Translate" and a "Transliterate Only" button,
    a set of examples and a model-information section. The app is returned
    without being launched.
    """

    # Handler for the "Transliterate Only" button
    def direct_transliterate(text):
        return transliterate_text(text) if text.strip() else "Please enter text to transliterate"

    # Sample rows: (input text, source language, target language)
    sample_rows = [
        ["Hello, how are you?", "English", "Hindi"],
        ["नमस्ते, आप कैसे हैं?", "Hindi", "English"],
        ["Hello, how are you?", "English", "Marathi"],
        ["नमस्कार, तुम्ही कसे आहात?", "Marathi", "English"],
    ]

    with gr.Blocks(title="Neural Machine Translation - Indian Languages") as ui:
        gr.Markdown("# Neural Machine Translation for Indian Languages")
        gr.Markdown("Translate between English, Hindi, and Marathi using MarianMT models")

        with gr.Row():
            # Left column: what to translate and from which language
            with gr.Column():
                src_choice = gr.Dropdown(
                    label="Source Language",
                    choices=list(language_codes),
                    value="English",
                )
                text_in = gr.Textbox(
                    label="Input Text",
                    placeholder="Enter text to translate...",
                    lines=5,
                )

            # Right column: target language and the result
            with gr.Column():
                tgt_choice = gr.Dropdown(
                    label="Target Language",
                    choices=list(language_codes),
                    value="Hindi",
                )
                text_out = gr.Textbox(
                    label="Translated Text",
                    placeholder="Translation will appear here...",
                    lines=5,
                )

        run_translation = gr.Button("Translate", variant="primary")
        run_transliteration = gr.Button("Transliterate Only", variant="secondary")

        # Button wiring
        run_translation.click(
            fn=perform_translation,
            inputs=[text_in, src_choice, tgt_choice],
            outputs=[text_out],
            api_name="translate",
        )
        run_transliteration.click(
            fn=direct_transliterate,
            inputs=[text_in],
            outputs=[text_out],
            api_name="transliterate",
        )

        # Examples covering every language pair; cache_examples=True asks Gradio
        # to cache their outputs (NOTE(review): presumably this runs
        # perform_translation while the UI is built, so models must already be
        # loaded - confirm for the Gradio version in use)
        gr.Examples(
            examples=sample_rows,
            inputs=[text_in, src_choice, tgt_choice],
            outputs=text_out,
            fn=perform_translation,
            cache_examples=True,
        )

        gr.Markdown("""
        ## Model Information

        This demo uses fine-tuned MarianMT models for translation between:
        - English ↔️ Hindi
        - English ↔️ Marathi

        ### Features:
        - Bidirectional translation support
        - Transliteration support for romanized Indic text
        - Optimized models for each language pair
        """)

    return ui

# Script entry point: load the models, then build and launch the UI
if __name__ == "__main__":
    # The interface is only created once every model has loaded
    if not load_models():
        print("Failed to load models. Please check the model paths and try again.")
    else:
        demo = create_interface()
        demo.launch(share=False)