File size: 6,724 Bytes
81f14ee
 
 
 
 
 
 
 
 
 
 
 
 
 
b1d5be4
 
 
 
 
81f14ee
 
 
 
 
 
 
 
 
 
 
 
 
 
b1d5be4
 
 
81f14ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1d5be4
 
 
81f14ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1d5be4
81f14ee
 
 
 
b1d5be4
 
 
 
 
 
 
81f14ee
 
 
 
 
b1d5be4
81f14ee
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
b1d5be4
 
81f14ee
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
import gradio as gr
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import subprocess
import sys
import os

# --- Configuration ---
SYLHETI_TO_BN_MODEL = "shbhro/sylhetit5"  # HF Hub id: Sylheti -> Bengali seq2seq model
BN_TO_EN_MODEL = "csebuetnlp/banglat5_nmt_bn_en"  # HF Hub id: Bengali -> English NMT model
NORMALIZER_REPO = "https://github.com/csebuetnlp/normalizer.git"  # Source repo for the Bengali text normalizer

# --- Helper function to install/import normalizer ---
# Resolved once at import time: `normalizer_module` ends up as either the real
# `normalize` function from the csebuetnlp/normalizer package, or
# `dummy_normalize_func` when the package cannot be installed.
normalizer_module = None
dummy_normalizer_flag = False # Flag to indicate if dummy is used

def dummy_normalize_func(text): # Define the dummy function clearly
    """Placeholder normalizer: always raises to signal the real library is missing."""
    raise RuntimeError("Normalizer library could not be loaded. Please check installation and logs.")

try:
    # Happy path: the package is already installed (e.g. via requirements.txt).
    from normalizer import normalize as normalize_fn_imported
    normalizer_module = normalize_fn_imported
    print("Normalizer imported successfully.")
except ImportError:
    # Fallback: try a runtime pip install straight from the git repository.
    print(f"Normalizer library not found. Attempting to install from {NORMALIZER_REPO}...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"git+{NORMALIZER_REPO}#egg=normalizer"])
        from normalizer import normalize as normalize_fn_imported_after_install
        normalizer_module = normalize_fn_imported_after_install
        print("Normalizer installed and imported successfully after pip install.")
    except Exception as e:
        # Install or re-import failed (no network, git missing, etc.): fall
        # back to the dummy so later code gets a clear RuntimeError rather
        # than an unbound name.
        print(f"Failed to install or import normalizer: {e}")
        print("Please ensure 'git+https://github.com/csebuetnlp/normalizer.git#egg=normalizer' is in your requirements.txt for Hugging Face Spaces.")
        normalizer_module = dummy_normalize_func # Assign the actual dummy function
        dummy_normalizer_flag = True


# --- Model Loading (Globally, when the script starts) ---
# Loaded once at import time so every Gradio request reuses the same model
# instances. On any failure all handles stay None and the translation
# function reports the problem to the user instead of crashing the app.
sylheti_to_bn_pipe = None   # HF text2text pipeline (Sylheti -> Bengali)
bn_to_en_model = None       # Seq2seq model (Bengali -> English)
bn_to_en_tokenizer = None   # Tokenizer paired with bn_to_en_model
model_device = None         # torch.device that bn_to_en_model/inputs live on

print("Loading translation models...")
try:
    # Prefer GPU when available. The pipeline API takes an int device index
    # (0 = first CUDA device, -1 = CPU) rather than a torch.device.
    model_device_type = "cuda" if torch.cuda.is_available() else "cpu"
    model_device = torch.device(model_device_type)
    hf_device_param = 0 if model_device_type == "cuda" else -1 # For pipeline

    print(f"Using device: {model_device_type}")

    sylheti_to_bn_pipe = pipeline(
        "text2text-generation",
        model=SYLHETI_TO_BN_MODEL,
        device=hf_device_param
    )
    print(f"Sylheti-to-Bengali model ({SYLHETI_TO_BN_MODEL}) loaded.")

    # use_fast=False forces the Python ("slow") tokenizer — presumably
    # required by this checkpoint's SentencePiece vocab; confirm if changed.
    bn_to_en_model = AutoModelForSeq2SeqLM.from_pretrained(BN_TO_EN_MODEL)
    bn_to_en_tokenizer = AutoTokenizer.from_pretrained(BN_TO_EN_MODEL, use_fast=False)
    bn_to_en_model.to(model_device)
    print(f"Bengali-to-English model ({BN_TO_EN_MODEL}) loaded.")

except Exception as e:
    # Reset everything so the partial-load state is never used downstream.
    print(f"FATAL: Error loading one or more models: {e}")
    sylheti_to_bn_pipe = None
    bn_to_en_model = None
    bn_to_en_tokenizer = None

# --- Main Translation Logic ---
def translate_sylheti_to_english_gradio(sylheti_text_input):
    """Translate Sylheti text to English via an intermediate Bengali step.

    Args:
        sylheti_text_input: Raw text from the Gradio textbox. May be ``None``
            (cleared textbox) or empty/whitespace-only, in which case a
            prompt message is returned instead of a translation.

    Returns:
        A 2-tuple ``(bengali_text, english_text)`` feeding the two output
        textboxes. On failure, one or both elements carry an error message.
    """
    # Bug fix: Gradio can deliver None for a cleared textbox; treat it like
    # empty input instead of raising AttributeError on None.strip().
    if not sylheti_text_input or not sylheti_text_input.strip():
        return "Please enter some Sylheti text.", ""

    # Fail fast with readable messages if the global model loading above
    # failed — the globals are then None.
    if not sylheti_to_bn_pipe:
        return "Error: Sylheti-to-Bengali model not loaded. Check logs.", ""
    if not bn_to_en_model or not bn_to_en_tokenizer:
        return "Error: Bengali-to-English model not loaded. Check logs.", ""

    # Check if the normalizer fell back to the dummy (or never resolved).
    if dummy_normalizer_flag or normalizer_module is None:
        return "Error: Bengali normalizer library not available. Check logs.", ""

    # Defaults shown if the corresponding step never completes.
    bengali_text_intermediate = "Error in Sylheti to Bengali step."
    english_text_final = "Error in Bengali to English step."

    # Step 1: Sylheti → Bengali
    try:
        print(f"Translating Sylheti to Bengali: '{sylheti_text_input}'")
        bengali_translation_outputs = sylheti_to_bn_pipe(
            sylheti_text_input,
            max_length=128,
            num_beams=5,
            early_stopping=True
        )
        bengali_text_intermediate = bengali_translation_outputs[0]['generated_text']
        print(f"Intermediate Bengali: '{bengali_text_intermediate}'")
    except Exception as e:
        # Surface the step-1 failure in the Bengali box; step 2 is skipped.
        print(f"Error during Sylheti to Bengali translation: {e}")
        bengali_text_intermediate = f"Sylheti->Bengali Error: {str(e)}"
        return bengali_text_intermediate, english_text_final

    # Step 2: Bengali → English
    try:
        print(f"Normalizing and translating Bengali to English: '{bengali_text_intermediate}'")
        # Safeguard: the availability checks above should guarantee this,
        # but never call a non-callable.
        if callable(normalizer_module):
            normalized_bn_text = normalizer_module(bengali_text_intermediate)
        else:
            raise RuntimeError("Normalizer function is not callable.")

        print(f"Normalized Bengali: '{normalized_bn_text}'")

        # Tokenize on the same device the model was moved to.
        input_ids = bn_to_en_tokenizer(
            normalized_bn_text,
            return_tensors="pt"
        ).input_ids.to(model_device)

        generated_tokens = bn_to_en_model.generate(
            input_ids,
            max_length=128,
            num_beams=5,
            early_stopping=True
        )
        english_text_list = bn_to_en_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
        english_text_final = english_text_list[0] if english_text_list else "No English output generated."
        print(f"Final English: '{english_text_final}'")
    except Exception as e:
        # Step 1 succeeded, so still show the Bengali text alongside the error.
        print(f"Error during Bengali to English translation: {e}")
        english_text_final = f"Bengali->English Error: {str(e)}"

    return bengali_text_intermediate, english_text_final

# --- Gradio Interface Definition ---
# UI copy and sample inputs are hoisted into module-private constants so the
# Interface construction below stays compact.
_APP_TITLE = "🌍 Sylheti to English Translator (via Bengali)"
_APP_DESCRIPTION = (
    "Translates Sylheti text to English in two steps:\n"
    f"1. Sylheti → Bengali (using `{SYLHETI_TO_BN_MODEL}`)\n"
    f"2. Bengali → English (using `{BN_TO_EN_MODEL}` with text normalization from `{NORMALIZER_REPO.split('/')[-1]}`)"
)
_EXAMPLE_INPUTS = [
    ["কিতা কিতা কিনলায় তে?"],
    ["তুমি কিতা কররায়?"],
    ["আমি ভাত খাইছি।"],
    ["আফনে ভালা আছনি?"],
]

_input_box = gr.Textbox(
    lines=4,
    label="Enter Sylheti Text",
    placeholder="কিতা কিতা কিনলায় তে?"
)
_output_boxes = [
    gr.Textbox(label="Intermediate Bengali Output", lines=4),
    gr.Textbox(label="Final English Output", lines=4),
]

iface = gr.Interface(
    fn=translate_sylheti_to_english_gradio,
    inputs=_input_box,
    outputs=_output_boxes,
    title=_APP_TITLE,
    description=_APP_DESCRIPTION,
    examples=_EXAMPLE_INPUTS,
    allow_flagging="never",
    cache_examples=False, # Explicitly disable example caching
    theme=gr.themes.Soft()
)

# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()