File size: 6,724 Bytes
81f14ee b1d5be4 81f14ee b1d5be4 81f14ee b1d5be4 81f14ee b1d5be4 81f14ee b1d5be4 81f14ee b1d5be4 81f14ee b1d5be4 81f14ee |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 |
import gradio as gr
from transformers import pipeline, AutoModelForSeq2SeqLM, AutoTokenizer
import torch
import subprocess
import sys
import os
# --- Configuration ---
SYLHETI_TO_BN_MODEL = "shbhro/sylhetit5"
BN_TO_EN_MODEL = "csebuetnlp/banglat5_nmt_bn_en"
NORMALIZER_REPO = "https://github.com/csebuetnlp/normalizer.git"

# --- Helper function to install/import normalizer ---
# Callable used to normalize Bengali text before the Bengali->English model.
# Set below to the real `normalizer.normalize` on success, or to
# `dummy_normalize_func` if the library can neither be imported nor installed.
normalizer_module = None
# True when `normalizer_module` is the dummy fallback; the translation callback
# checks this flag and reports an error instead of translating.
dummy_normalizer_flag = False


def dummy_normalize_func(text):
    """Fallback normalizer that always fails.

    Installed in place of the real normalizer when it could not be loaded, so
    any call surfaces a clear error rather than skipping normalization.

    Args:
        text: Ignored.

    Raises:
        RuntimeError: Always.
    """
    raise RuntimeError("Normalizer library could not be loaded. Please check installation and logs.")


try:
    from normalizer import normalize as normalize_fn_imported
    normalizer_module = normalize_fn_imported
    print("Normalizer imported successfully.")
except ImportError:
    print(f"Normalizer library not found. Attempting to install from {NORMALIZER_REPO}...")
    try:
        subprocess.check_call([sys.executable, "-m", "pip", "install", f"git+{NORMALIZER_REPO}#egg=normalizer"])
        # The import system caches directory contents; after installing a
        # package into a running interpreter the caches must be invalidated,
        # otherwise the freshly installed module may not be found.
        import importlib
        importlib.invalidate_caches()
        from normalizer import normalize as normalize_fn_imported_after_install
        normalizer_module = normalize_fn_imported_after_install
        print("Normalizer installed and imported successfully after pip install.")
    except Exception as e:
        print(f"Failed to install or import normalizer: {e}")
        print(f"Please ensure 'git+{NORMALIZER_REPO}#egg=normalizer' is in your requirements.txt for Hugging Face Spaces.")
        normalizer_module = dummy_normalize_func  # Assign the actual dummy function
        dummy_normalizer_flag = True
# --- Model Loading (Globally, when the script starts) ---
# Module-level handles populated below. They are left as None when loading
# fails so the translation callback can report "not loaded" instead of crashing.
sylheti_to_bn_pipe = None
bn_to_en_model = None
bn_to_en_tokenizer = None
model_device = None

print("Loading translation models...")
try:
    model_device_type = "cpu"
    if torch.cuda.is_available():
        model_device_type = "cuda"
    model_device = torch.device(model_device_type)

    # The pipeline is given an integer device index: 0 for the first GPU, -1 for CPU.
    hf_device_param = -1
    if model_device_type == "cuda":
        hf_device_param = 0
    print(f"Using device: {model_device_type}")

    sylheti_to_bn_pipe = pipeline("text2text-generation", model=SYLHETI_TO_BN_MODEL, device=hf_device_param)
    print(f"Sylheti-to-Bengali model ({SYLHETI_TO_BN_MODEL}) loaded.")

    bn_to_en_model = AutoModelForSeq2SeqLM.from_pretrained(BN_TO_EN_MODEL)
    bn_to_en_tokenizer = AutoTokenizer.from_pretrained(BN_TO_EN_MODEL, use_fast=False)
    bn_to_en_model.to(model_device)
    print(f"Bengali-to-English model ({BN_TO_EN_MODEL}) loaded.")
except Exception as e:
    print(f"FATAL: Error loading one or more models: {e}")
    # All-or-nothing: discard anything that loaded before the failure so the
    # callback sees a consistent "not loaded" state.
    sylheti_to_bn_pipe = bn_to_en_model = bn_to_en_tokenizer = None
# --- Main Translation Logic ---
def translate_sylheti_to_english_gradio(sylheti_text_input):
if not sylheti_text_input.strip():
return "Please enter some Sylheti text.", ""
if not sylheti_to_bn_pipe:
return "Error: Sylheti-to-Bengali model not loaded. Check logs.", ""
if not bn_to_en_model or not bn_to_en_tokenizer:
return "Error: Bengali-to-English model not loaded. Check logs.", ""
# Check if the normalizer is the dummy function
if dummy_normalizer_flag or normalizer_module is None:
return "Error: Bengali normalizer library not available. Check logs.", ""
bengali_text_intermediate = "Error in Sylheti to Bengali step."
english_text_final = "Error in Bengali to English step."
# Step 1: Sylheti → Bengali
try:
print(f"Translating Sylheti to Bengali: '{sylheti_text_input}'")
bengali_translation_outputs = sylheti_to_bn_pipe(
sylheti_text_input,
max_length=128,
num_beams=5,
early_stopping=True
)
bengali_text_intermediate = bengali_translation_outputs[0]['generated_text']
print(f"Intermediate Bengali: '{bengali_text_intermediate}'")
except Exception as e:
print(f"Error during Sylheti to Bengali translation: {e}")
bengali_text_intermediate = f"Sylheti->Bengali Error: {str(e)}"
return bengali_text_intermediate, english_text_final
# Step 2: Bengali → English
try:
print(f"Normalizing and translating Bengali to English: '{bengali_text_intermediate}'")
# Ensure normalizer_module is callable before calling
if callable(normalizer_module):
normalized_bn_text = normalizer_module(bengali_text_intermediate)
else:
# This case should ideally be caught by the check above, but as a safeguard:
raise RuntimeError("Normalizer function is not callable.")
print(f"Normalized Bengali: '{normalized_bn_text}'")
input_ids = bn_to_en_tokenizer(
normalized_bn_text,
return_tensors="pt"
).input_ids.to(model_device)
generated_tokens = bn_to_en_model.generate(
input_ids,
max_length=128,
num_beams=5,
early_stopping=True
)
english_text_list = bn_to_en_tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
english_text_final = english_text_list[0] if english_text_list else "No English output generated."
print(f"Final English: '{english_text_final}'")
except Exception as e:
print(f"Error during Bengali to English translation: {e}")
english_text_final = f"Bengali->English Error: {str(e)}"
return bengali_text_intermediate, english_text_final
# --- Gradio Interface Definition ---
# Last path segment of the normalizer repo URL, shown in the description.
_normalizer_name = NORMALIZER_REPO.rsplit("/", 1)[-1]
_description = (
    "Translates Sylheti text to English in two steps:\n"
    f"1. Sylheti → Bengali (using `{SYLHETI_TO_BN_MODEL}`)\n"
    f"2. Bengali → English (using `{BN_TO_EN_MODEL}` with text normalization from `{_normalizer_name}`)"
)
# Sample Sylheti inputs offered under the textbox; each becomes a one-field example row.
_example_sentences = [
    "কিতা কিতা কিনলায় তে?",
    "তুমি কিতা কররায়?",
    "আমি ভাত খাইছি।",
    "আফনে ভালা আছনি?",
]

iface = gr.Interface(
    fn=translate_sylheti_to_english_gradio,
    inputs=gr.Textbox(lines=4, label="Enter Sylheti Text", placeholder="কিতা কিতা কিনলায় তে?"),
    outputs=[
        gr.Textbox(label="Intermediate Bengali Output", lines=4),
        gr.Textbox(label="Final English Output", lines=4),
    ],
    title="🌍 Sylheti to English Translator (via Bengali)",
    description=_description,
    examples=[[sentence] for sentence in _example_sentences],
    allow_flagging="never",
    cache_examples=False,  # Explicitly disable example caching
    theme=gr.themes.Soft(),
)

# --- Launch the Gradio app ---
if __name__ == "__main__":
    iface.launch()
|