Build error
Update app.py
app.py CHANGED
@@ -23,7 +23,7 @@ st.set_page_config(
     layout="wide"
 )
 
-# Display current information in sidebar
+# Display current information in sidebar
 current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
 st.sidebar.markdown("""
 ### System Information
@@ -76,8 +76,8 @@ def load_models():
     nllb_tokenizer = AutoTokenizer.from_pretrained(
         "facebook/nllb-200-distilled-600M",
         token=HF_TOKEN,
-
-
+        src_lang="eng_Latn",
+        trust_remote_code=True
     )
     nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
         "facebook/nllb-200-distilled-600M",
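Note: src_lang tells the NLLB tokenizer to insert the source-language token into the encoder input itself, which is what lets the manual language prefixing be dropped from translate_text later in this diff. A minimal sketch of the effect (standalone illustration, not code from app.py; the exact token order shown is the library's current default behaviour):

from transformers import AutoTokenizer

# With src_lang set, the tokenizer adds the language code token itself,
# so the raw text no longer needs a manual "eng_Latn ..." prefix.
tok = AutoTokenizer.from_pretrained(
    "facebook/nllb-200-distilled-600M", src_lang="eng_Latn"
)
ids = tok("Hello!")["input_ids"]
print(tok.convert_ids_to_tokens(ids))  # typically ['eng_Latn', '▁Hello', '!', '</s>']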
@@ -89,14 +89,12 @@ def load_models():
 
     # Load MT5 model for grammar correction
     mt5_tokenizer = AutoTokenizer.from_pretrained(
-        "google/mt5-
+        "google/mt5-base",  # Changed to base model for better performance
         token=HF_TOKEN,
-        trust_remote_code=True
-        legacy=False,  # Use new behavior
-        use_fast=False  # Use slow tokenizer to avoid warnings
+        trust_remote_code=True
     )
     mt5_model = MT5ForConditionalGeneration.from_pretrained(
-        "google/mt5-
+        "google/mt5-base",  # Changed to base model for better performance
         token=HF_TOKEN,
         torch_dtype=torch.float16,
         device_map="auto" if torch.cuda.is_available() else None,
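Note: torch_dtype=torch.float16 is applied unconditionally here, but half precision on CPU is slow and several ops are unsupported, which can surface as runtime errors on a CPU-only Space. A guarded dtype is the safer pattern; this is a sketch of an alternative, not what the commit does:

import torch
from transformers import AutoTokenizer, MT5ForConditionalGeneration

# Use fp16 only when a GPU is available; fall back to fp32 on CPU, where
# half-precision kernels are poorly supported and can fail at generate time.
dtype = torch.float16 if torch.cuda.is_available() else torch.float32
mt5_tokenizer = AutoTokenizer.from_pretrained("google/mt5-base")
mt5_model = MT5ForConditionalGeneration.from_pretrained(
    "google/mt5-base",
    torch_dtype=dtype,
    device_map="auto" if torch.cuda.is_available() else None,
)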
@@ -177,9 +175,7 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
     interpreted_batches = []
 
     for batch in batches:
-        prompt = f"""Analyze
-        maintaining the core meaning while identifying any idiomatic expressions or
-        cultural references: {batch}"""
+        prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""
 
         inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
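Note: interpret_context drives a Gemma model (per the gemma_tuple signature). If the checkpoint is an instruction-tuned (-it) variant, it expects the Gemma chat template rather than a bare f-string. A sketch of that variant, under the assumption that the tokenizer ships a chat template:

# Wrap the instruction in the model's chat template instead of a raw f-string;
# apply_chat_template emits the turn markers instruction-tuned Gemma expects.
messages = [{"role": "user",
             "content": f"Analyze and maintain the core meaning of this text: {batch}"}]
input_ids = tokenizer.apply_chat_template(
    messages, add_generation_prompt=True, return_tensors="pt"
).to(model.device)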
@@ -194,6 +190,8 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
         )
 
         interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Remove the prompt from the output
+        interpreted_text = interpreted_text.replace(prompt, "").strip()
         interpreted_batches.append(interpreted_text)
 
     return " ".join(interpreted_batches)
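Note: replace(prompt, "") only works if the decoded output echoes the prompt byte-for-byte; decoding can normalize whitespace and make the replacement silently fail. Slicing off the input tokens before decoding is more robust, sketched here with the same variable names as above:

# Decode only the newly generated tokens, skipping the prompt's token span.
prompt_len = inputs["input_ids"].shape[1]
interpreted_text = tokenizer.decode(
    outputs[0][prompt_len:], skip_special_tokens=True
).strip()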
@@ -207,17 +205,12 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tuple) -> str:
     translated_batches = []
 
     for batch in batches:
-
-        batch_with_lang = f"{source_lang} {batch}"
-        inputs = tokenizer(batch_with_lang, return_tensors="pt", max_length=512, truncation=True)
+        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-        # Add target language token
-        target_lang_token = tokenizer(target_lang, add_special_tokens=False)["input_ids"][0]
-
         outputs = model.generate(
             **inputs,
-            forced_bos_token_id=
+            forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
             max_length=512,
             do_sample=True,
             temperature=0.7,
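Note: tokenizer.lang_code_to_id was deprecated for NLLB tokenizers and removed in newer transformers releases, a common cause of exactly the kind of failure this Space's "Build error" badge suggests. If the pinned transformers version lacks the attribute, the language codes are ordinary vocabulary tokens and convert_tokens_to_ids is the drop-in replacement; a hedged sketch:

# NLLB language codes (e.g. "hin_Deva") are regular tokens, so their ids can
# be looked up without the removed lang_code_to_id mapping.
target_id = tokenizer.convert_tokens_to_ids(target_lang)
outputs = model.generate(
    **inputs,
    forced_bos_token_id=target_id,
    max_length=512,
)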
@@ -236,35 +229,36 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
     tokenizer, model = mt5_tuple
     lang_code = MT5_LANG_CODES[target_lang]
 
+    # Language-specific prompts for grammar correction
     prompts = {
-        'en': "grammar: ",
-        'hi': "
-        'mr': "
+        'en': "Fix grammar: ",
+        'hi': "व्याकरण: ",
+        'mr': "व्याकरण: "
     }
 
     batches = batch_process_text(text)
     corrected_batches = []
 
     for batch in batches:
-
-
+        # Prepare input with target language prefix
+        input_text = f"{prompts[lang_code]}{batch}"
+        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
         outputs = model.generate(
             **inputs,
             max_length=512,
             num_beams=5,
-
-
-
-            num_return_sequences=1
+            length_penalty=1.0,
+            early_stopping=True,
+            do_sample=False  # Disable sampling for more stable output
         )
 
         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Clean up the output
         for prefix in prompts.values():
             corrected_text = corrected_text.replace(prefix, "")
-        corrected_text = corrected_text.strip()
-
+        corrected_text = corrected_text.replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
         corrected_batches.append(corrected_text)
 
     return " ".join(corrected_batches)
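Note: stripping <extra_id_0>/<extra_id_1> points at the underlying limitation: raw google/mt5-base is pretrained only on span corruption, so a "Fix grammar: " prefix tends to produce sentinel tokens rather than corrections unless the model has been fine-tuned for the task. If sentinels still leak through decode, a pattern-based cleanup catches every <extra_id_N>; the helper name below is illustrative, not from app.py:

import re

def strip_mt5_sentinels(text: str) -> str:
    # Remove any <extra_id_N> sentinel mt5 emits from span-corruption
    # pretraining, then collapse the leftover whitespace.
    text = re.sub(r"<extra_id_\d+>", "", text)
    return " ".join(text.split())

print(strip_mt5_sentinels("<extra_id_0> She goes to school. <extra_id_1>"))
# -> "She goes to school."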