Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -10,11 +10,26 @@ from typing import Union, Tuple
|
|
10 |
import os
|
11 |
import sys
|
12 |
from datetime import datetime, timezone
|
|
|
13 |
|
14 |
-
#
|
15 |
-
|
16 |
-
|
17 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
18 |
|
19 |
# Get Hugging Face token from environment variables
|
20 |
HF_TOKEN = os.environ.get('HF_TOKEN')
|
@@ -61,7 +76,8 @@ def load_models():
|
|
61 |
nllb_tokenizer = AutoTokenizer.from_pretrained(
|
62 |
"facebook/nllb-200-distilled-600M",
|
63 |
token=HF_TOKEN,
|
64 |
-
trust_remote_code=True
|
|
|
65 |
)
|
66 |
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
|
67 |
"facebook/nllb-200-distilled-600M",
|
@@ -75,7 +91,9 @@ def load_models():
|
|
75 |
mt5_tokenizer = AutoTokenizer.from_pretrained(
|
76 |
"google/mt5-small",
|
77 |
token=HF_TOKEN,
|
78 |
-
trust_remote_code=True
|
|
|
|
|
79 |
)
|
80 |
mt5_model = MT5ForConditionalGeneration.from_pretrained(
|
81 |
"google/mt5-small",
|
@@ -155,7 +173,6 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
|
|
155 |
"""Use Gemma model to interpret context and understand regional nuances."""
|
156 |
tokenizer, model = gemma_tuple
|
157 |
|
158 |
-
# Split text into batches
|
159 |
batches = batch_process_text(text)
|
160 |
interpreted_batches = []
|
161 |
|
@@ -186,23 +203,21 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
|
|
186 |
"""Translate text using NLLB model."""
|
187 |
tokenizer, model = nllb_tuple
|
188 |
|
189 |
-
# Split text into batches
|
190 |
batches = batch_process_text(text)
|
191 |
translated_batches = []
|
192 |
|
193 |
for batch in batches:
|
194 |
-
#
|
195 |
-
|
|
|
196 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
197 |
|
198 |
-
#
|
199 |
-
target_lang_token =
|
200 |
-
target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
|
201 |
|
202 |
-
# Generate translation
|
203 |
outputs = model.generate(
|
204 |
**inputs,
|
205 |
-
forced_bos_token_id=
|
206 |
max_length=512,
|
207 |
do_sample=True,
|
208 |
temperature=0.7,
|
@@ -217,21 +232,16 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
|
|
217 |
|
218 |
@torch.no_grad()
|
219 |
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
220 |
-
"""
|
221 |
-
Correct grammar using MT5 model for all supported languages.
|
222 |
-
Uses a text-to-text approach with language-specific prompts.
|
223 |
-
"""
|
224 |
tokenizer, model = mt5_tuple
|
225 |
lang_code = MT5_LANG_CODES[target_lang]
|
226 |
|
227 |
-
# Language-specific prompts for grammar correction
|
228 |
prompts = {
|
229 |
'en': "grammar: ",
|
230 |
'hi': "व्याकरण सुधार: ",
|
231 |
'mr': "व्याकरण सुधारणा: "
|
232 |
}
|
233 |
|
234 |
-
# Split text into batches
|
235 |
batches = batch_process_text(text)
|
236 |
corrected_batches = []
|
237 |
|
@@ -251,8 +261,6 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
|
251 |
)
|
252 |
|
253 |
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
254 |
-
|
255 |
-
# Clean up any artifacts from the model output
|
256 |
for prefix in prompts.values():
|
257 |
corrected_text = corrected_text.replace(prefix, "")
|
258 |
corrected_text = corrected_text.strip()
|
@@ -273,7 +281,7 @@ def save_as_docx(text: str) -> io.BytesIO:
|
|
273 |
return docx_buffer
|
274 |
|
275 |
def main():
|
276 |
-
st.title("Document Translation App")
|
277 |
|
278 |
# Load models
|
279 |
with st.spinner("Loading models... This may take a few minutes."):
|
@@ -306,40 +314,52 @@ def main():
|
|
306 |
index=1
|
307 |
)
|
308 |
|
309 |
-
if uploaded_file and st.button("Translate"):
|
310 |
try:
|
311 |
-
|
312 |
-
|
313 |
-
|
314 |
-
|
315 |
-
|
316 |
-
|
317 |
-
|
318 |
-
|
319 |
-
|
320 |
-
|
321 |
-
|
322 |
-
|
323 |
-
|
324 |
-
|
325 |
-
|
326 |
-
|
327 |
-
|
328 |
-
|
329 |
-
|
330 |
-
|
331 |
-
|
332 |
-
|
333 |
-
|
334 |
-
|
335 |
-
|
336 |
-
|
337 |
-
|
338 |
-
|
339 |
-
|
340 |
-
|
341 |
-
|
342 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
343 |
# Text file download
|
344 |
text_buffer = io.BytesIO()
|
345 |
text_buffer.write(corrected_text.encode())
|
@@ -351,7 +371,8 @@ def main():
|
|
351 |
file_name="translated_document.txt",
|
352 |
mime="text/plain"
|
353 |
)
|
354 |
-
|
|
|
355 |
# DOCX file download
|
356 |
docx_buffer = save_as_docx(corrected_text)
|
357 |
st.download_button(
|
@@ -360,7 +381,9 @@ def main():
|
|
360 |
file_name="translated_document.docx",
|
361 |
mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
362 |
)
|
363 |
-
|
|
|
|
|
364 |
except Exception as e:
|
365 |
st.error(f"An error occurred: {str(e)}")
|
366 |
|
|
|
10 |
import os
|
11 |
import sys
|
12 |
from datetime import datetime, timezone
|
13 |
+
import warnings
|
14 |
|
15 |
+
# Filter out specific warnings
|
16 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.convert_slow_tokenizer')
|
17 |
+
warnings.filterwarnings('ignore', category=UserWarning, module='transformers.tokenization_utils_base')
|
18 |
+
|
19 |
+
# Custom styling
|
20 |
+
st.set_page_config(
|
21 |
+
page_title="Document Translation App",
|
22 |
+
page_icon="🌐",
|
23 |
+
layout="wide"
|
24 |
+
)
|
25 |
+
|
26 |
+
# Display current information in sidebar with proper formatting
|
27 |
+
current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
|
28 |
+
st.sidebar.markdown("""
|
29 |
+
### System Information
|
30 |
+
**Current UTC Time:** {}
|
31 |
+
**User:** {}
|
32 |
+
""".format(current_time, os.environ.get('USER', 'gauravchand')))
|
33 |
|
34 |
# Get Hugging Face token from environment variables
|
35 |
HF_TOKEN = os.environ.get('HF_TOKEN')
|
|
|
76 |
nllb_tokenizer = AutoTokenizer.from_pretrained(
|
77 |
"facebook/nllb-200-distilled-600M",
|
78 |
token=HF_TOKEN,
|
79 |
+
trust_remote_code=True,
|
80 |
+
use_fast=False # Use slow tokenizer to avoid warnings
|
81 |
)
|
82 |
nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
|
83 |
"facebook/nllb-200-distilled-600M",
|
|
|
91 |
mt5_tokenizer = AutoTokenizer.from_pretrained(
|
92 |
"google/mt5-small",
|
93 |
token=HF_TOKEN,
|
94 |
+
trust_remote_code=True,
|
95 |
+
legacy=False, # Use new behavior
|
96 |
+
use_fast=False # Use slow tokenizer to avoid warnings
|
97 |
)
|
98 |
mt5_model = MT5ForConditionalGeneration.from_pretrained(
|
99 |
"google/mt5-small",
|
|
|
173 |
"""Use Gemma model to interpret context and understand regional nuances."""
|
174 |
tokenizer, model = gemma_tuple
|
175 |
|
|
|
176 |
batches = batch_process_text(text)
|
177 |
interpreted_batches = []
|
178 |
|
|
|
203 |
"""Translate text using NLLB model."""
|
204 |
tokenizer, model = nllb_tuple
|
205 |
|
|
|
206 |
batches = batch_process_text(text)
|
207 |
translated_batches = []
|
208 |
|
209 |
for batch in batches:
|
210 |
+
# Add source language token to input
|
211 |
+
batch_with_lang = f"{source_lang} {batch}"
|
212 |
+
inputs = tokenizer(batch_with_lang, return_tensors="pt", max_length=512, truncation=True)
|
213 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
214 |
|
215 |
+
# Add target language token
|
216 |
+
target_lang_token = tokenizer(target_lang, add_special_tokens=False)["input_ids"][0]
|
|
|
217 |
|
|
|
218 |
outputs = model.generate(
|
219 |
**inputs,
|
220 |
+
forced_bos_token_id=target_lang_token,
|
221 |
max_length=512,
|
222 |
do_sample=True,
|
223 |
temperature=0.7,
|
|
|
232 |
|
233 |
@torch.no_grad()
|
234 |
def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
|
235 |
+
"""Correct grammar using MT5 model for all supported languages."""
|
|
|
|
|
|
|
236 |
tokenizer, model = mt5_tuple
|
237 |
lang_code = MT5_LANG_CODES[target_lang]
|
238 |
|
|
|
239 |
prompts = {
|
240 |
'en': "grammar: ",
|
241 |
'hi': "व्याकरण सुधार: ",
|
242 |
'mr': "व्याकरण सुधारणा: "
|
243 |
}
|
244 |
|
|
|
245 |
batches = batch_process_text(text)
|
246 |
corrected_batches = []
|
247 |
|
|
|
261 |
)
|
262 |
|
263 |
corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
|
|
|
|
264 |
for prefix in prompts.values():
|
265 |
corrected_text = corrected_text.replace(prefix, "")
|
266 |
corrected_text = corrected_text.strip()
|
|
|
281 |
return docx_buffer
|
282 |
|
283 |
def main():
|
284 |
+
st.title("🌐 Document Translation App")
|
285 |
|
286 |
# Load models
|
287 |
with st.spinner("Loading models... This may take a few minutes."):
|
|
|
314 |
index=1
|
315 |
)
|
316 |
|
317 |
+
if uploaded_file and st.button("Translate", type="primary"):
|
318 |
try:
|
319 |
+
progress_bar = st.progress(0)
|
320 |
+
|
321 |
+
# Extract text
|
322 |
+
text = extract_text_from_file(uploaded_file)
|
323 |
+
progress_bar.progress(20)
|
324 |
+
|
325 |
+
# Interpret context
|
326 |
+
with st.spinner("Interpreting context..."):
|
327 |
+
interpreted_text = interpret_context(text, gemma_tuple)
|
328 |
+
progress_bar.progress(40)
|
329 |
+
|
330 |
+
# Translate
|
331 |
+
with st.spinner("Translating..."):
|
332 |
+
translated_text = translate_text(
|
333 |
+
interpreted_text,
|
334 |
+
SUPPORTED_LANGUAGES[source_language],
|
335 |
+
SUPPORTED_LANGUAGES[target_language],
|
336 |
+
nllb_tuple
|
337 |
+
)
|
338 |
+
progress_bar.progress(70)
|
339 |
+
|
340 |
+
# Grammar correction
|
341 |
+
with st.spinner("Correcting grammar..."):
|
342 |
+
corrected_text = correct_grammar(
|
343 |
+
translated_text,
|
344 |
+
SUPPORTED_LANGUAGES[target_language],
|
345 |
+
mt5_tuple
|
346 |
+
)
|
347 |
+
progress_bar.progress(90)
|
348 |
+
|
349 |
+
# Display result
|
350 |
+
st.markdown("### Translation Result")
|
351 |
+
st.text_area(
|
352 |
+
label="Translated Text",
|
353 |
+
value=corrected_text,
|
354 |
+
height=200,
|
355 |
+
key="translation_result"
|
356 |
+
)
|
357 |
+
|
358 |
+
# Download options
|
359 |
+
st.markdown("### Download Options")
|
360 |
+
col1, col2 = st.columns(2)
|
361 |
+
|
362 |
+
with col1:
|
363 |
# Text file download
|
364 |
text_buffer = io.BytesIO()
|
365 |
text_buffer.write(corrected_text.encode())
|
|
|
371 |
file_name="translated_document.txt",
|
372 |
mime="text/plain"
|
373 |
)
|
374 |
+
|
375 |
+
with col2:
|
376 |
# DOCX file download
|
377 |
docx_buffer = save_as_docx(corrected_text)
|
378 |
st.download_button(
|
|
|
381 |
file_name="translated_document.docx",
|
382 |
mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
383 |
)
|
384 |
+
|
385 |
+
progress_bar.progress(100)
|
386 |
+
|
387 |
except Exception as e:
|
388 |
st.error(f"An error occurred: {str(e)}")
|
389 |
|