Spaces:
Build error
Build error
Update app.py
Browse files
app.py
CHANGED
@@ -5,13 +5,12 @@ import io
|
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
|
6 |
import torch
|
7 |
from pathlib import Path
|
8 |
-
import tempfile
|
9 |
from typing import Union, Tuple, List, Dict
|
10 |
import os
|
11 |
import sys
|
12 |
from datetime import datetime, timezone
|
13 |
import warnings
|
14 |
-
import
|
15 |
|
16 |
# Filter warnings
|
17 |
warnings.filterwarnings('ignore', category=UserWarning)
|
@@ -105,7 +104,6 @@ class TextBatcher:
|
|
105 |
@staticmethod
|
106 |
def _split_into_sentences(text: str) -> List[str]:
|
107 |
"""Split text into sentences with improved boundary detection"""
|
108 |
-
# Basic sentence boundary detection
|
109 |
delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
|
110 |
sentences = []
|
111 |
current = text
|
@@ -131,14 +129,12 @@ class ModelManager:
|
|
131 |
try:
|
132 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
133 |
|
134 |
-
# Load models with improved error handling
|
135 |
models = {
|
136 |
"gemma": ModelManager._load_gemma_model(),
|
137 |
"nllb": ModelManager._load_nllb_model(),
|
138 |
"mt5": ModelManager._load_mt5_model()
|
139 |
}
|
140 |
|
141 |
-
# Move models to appropriate device
|
142 |
if not torch.cuda.is_available():
|
143 |
for model_tuple in models.values():
|
144 |
model_tuple[1].to(device)
|
@@ -208,7 +204,6 @@ class TranslationPipeline:
|
|
208 |
|
209 |
@torch.no_grad()
|
210 |
def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
|
211 |
-
# Split text into manageable batches
|
212 |
batches = TextBatcher.batch_process_text(text)
|
213 |
final_results = []
|
214 |
|
@@ -231,10 +226,11 @@ class TranslationPipeline:
|
|
231 |
|
232 |
final_results.append(corrected)
|
233 |
|
234 |
-
|
|
|
|
|
235 |
|
236 |
def _understand_context(self, text: str) -> str:
|
237 |
-
"""Enhanced context understanding using Gemma model"""
|
238 |
tokenizer, model = self.models["gemma"]
|
239 |
|
240 |
prompt = f"""Analyze and provide context for translation:
|
@@ -267,12 +263,9 @@ Provide a clear and concise interpretation that maintains:
|
|
267 |
return context.replace(prompt, "").strip()
|
268 |
|
269 |
def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
|
270 |
-
"""Enhanced translation using NLLB model with context awareness"""
|
271 |
tokenizer, model = self.models["nllb"]
|
272 |
|
273 |
-
source_lang_token = f"___{source_lang}___"
|
274 |
target_lang_token = f"___{target_lang}___"
|
275 |
-
|
276 |
inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
|
277 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
278 |
|
@@ -293,7 +286,6 @@ Provide a clear and concise interpretation that maintains:
|
|
293 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
294 |
|
295 |
def _correct_grammar(self, text: str, target_lang: str) -> str:
|
296 |
-
"""Enhanced grammar correction using MT5 model"""
|
297 |
tokenizer, model = self.models["mt5"]
|
298 |
lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
|
299 |
prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
|
@@ -313,9 +305,20 @@ Provide a clear and concise interpretation that maintains:
|
|
313 |
)
|
314 |
|
315 |
corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
316 |
-
|
317 |
-
|
318 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
319 |
|
320 |
class DocumentExporter:
|
321 |
"""Handles document export operations"""
|
@@ -328,31 +331,23 @@ class DocumentExporter:
|
|
328 |
buffer = io.BytesIO()
|
329 |
doc.save(buffer)
|
330 |
buffer.seek(0)
|
331 |
-
|
332 |
-
return buffer
|
333 |
-
|
334 |
-
@staticmethod
|
335 |
-
def save_as_text(text: str) -> io.BytesIO:
|
336 |
-
buffer = io.BytesIO()
|
337 |
-
buffer.write(text.encode())
|
338 |
-
buffer.seek(0)
|
339 |
return buffer
|
340 |
|
341 |
def main():
|
342 |
st.title("🌐 Enhanced Document Translation App")
|
343 |
|
344 |
-
# Check for HF_TOKEN
|
345 |
-
if not os.environ.get('HF_TOKEN'):
|
346 |
-
st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
|
347 |
-
st.stop()
|
348 |
-
|
349 |
# Display system info
|
350 |
st.sidebar.markdown(f"""
|
351 |
### System Information
|
352 |
**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
|
353 |
-
**User:** {os.environ.get('USER', '
|
354 |
""")
|
355 |
|
|
|
|
|
|
|
|
|
|
|
356 |
# Load models
|
357 |
with st.spinner("Loading models... This may take a few minutes."):
|
358 |
try:
|
@@ -412,25 +407,14 @@ def main():
|
|
412 |
key="translation_result"
|
413 |
)
|
414 |
|
415 |
-
# Download
|
416 |
-
st.markdown("### Download
|
417 |
-
|
418 |
-
|
419 |
-
|
420 |
-
|
421 |
-
|
422 |
-
|
423 |
-
file_name="translated_document.txt",
|
424 |
-
mime="text/plain"
|
425 |
-
)
|
426 |
-
|
427 |
-
with col2:
|
428 |
-
st.download_button(
|
429 |
-
label="Download as DOCX",
|
430 |
-
data=DocumentExporter.save_as_docx(final_text),
|
431 |
-
file_name="translated_document.docx",
|
432 |
-
mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
433 |
-
)
|
434 |
|
435 |
status_text.text("Translation completed successfully!")
|
436 |
progress_bar.progress(100)
|
|
|
5 |
from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
|
6 |
import torch
|
7 |
from pathlib import Path
|
|
|
8 |
from typing import Union, Tuple, List, Dict
|
9 |
import os
|
10 |
import sys
|
11 |
from datetime import datetime, timezone
|
12 |
import warnings
|
13 |
+
import re
|
14 |
|
15 |
# Filter warnings
|
16 |
warnings.filterwarnings('ignore', category=UserWarning)
|
|
|
104 |
@staticmethod
|
105 |
def _split_into_sentences(text: str) -> List[str]:
|
106 |
"""Split text into sentences with improved boundary detection"""
|
|
|
107 |
delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
|
108 |
sentences = []
|
109 |
current = text
|
|
|
129 |
try:
|
130 |
device = "cuda" if torch.cuda.is_available() else "cpu"
|
131 |
|
|
|
132 |
models = {
|
133 |
"gemma": ModelManager._load_gemma_model(),
|
134 |
"nllb": ModelManager._load_nllb_model(),
|
135 |
"mt5": ModelManager._load_mt5_model()
|
136 |
}
|
137 |
|
|
|
138 |
if not torch.cuda.is_available():
|
139 |
for model_tuple in models.values():
|
140 |
model_tuple[1].to(device)
|
|
|
204 |
|
205 |
@torch.no_grad()
|
206 |
def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
|
|
|
207 |
batches = TextBatcher.batch_process_text(text)
|
208 |
final_results = []
|
209 |
|
|
|
226 |
|
227 |
final_results.append(corrected)
|
228 |
|
229 |
+
# Clean up the final text
|
230 |
+
final_text = " ".join(final_results)
|
231 |
+
return self._clean_text(final_text)
|
232 |
|
233 |
def _understand_context(self, text: str) -> str:
|
|
|
234 |
tokenizer, model = self.models["gemma"]
|
235 |
|
236 |
prompt = f"""Analyze and provide context for translation:
|
|
|
263 |
return context.replace(prompt, "").strip()
|
264 |
|
265 |
def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
|
|
|
266 |
tokenizer, model = self.models["nllb"]
|
267 |
|
|
|
268 |
target_lang_token = f"___{target_lang}___"
|
|
|
269 |
inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
|
270 |
inputs = {k: v.to(model.device) for k, v in inputs.items()}
|
271 |
|
|
|
286 |
return tokenizer.decode(outputs[0], skip_special_tokens=True)
|
287 |
|
288 |
def _correct_grammar(self, text: str, target_lang: str) -> str:
|
|
|
289 |
tokenizer, model = self.models["mt5"]
|
290 |
lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
|
291 |
prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
|
|
|
305 |
)
|
306 |
|
307 |
corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
|
308 |
+
return self._clean_text(corrected.replace(prompt, "").strip())
|
309 |
+
|
310 |
+
def _clean_text(self, text: str) -> str:
|
311 |
+
"""Clean up the text by removing special tokens and fixing formatting"""
|
312 |
+
# Remove MT5 special tokens
|
313 |
+
text = re.sub(r'<extra_id_\d+>', '', text)
|
314 |
+
|
315 |
+
# Fix multiple spaces
|
316 |
+
text = re.sub(r'\s+', ' ', text)
|
317 |
+
|
318 |
+
# Fix punctuation spacing
|
319 |
+
text = re.sub(r'\s+([.,!?।॥])', r'\1', text)
|
320 |
+
|
321 |
+
return text.strip()
|
322 |
|
323 |
class DocumentExporter:
|
324 |
"""Handles document export operations"""
|
|
|
331 |
buffer = io.BytesIO()
|
332 |
doc.save(buffer)
|
333 |
buffer.seek(0)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
334 |
return buffer
|
335 |
|
336 |
def main():
|
337 |
st.title("🌐 Enhanced Document Translation App")
|
338 |
|
|
|
|
|
|
|
|
|
|
|
339 |
# Display system info
|
340 |
st.sidebar.markdown(f"""
|
341 |
### System Information
|
342 |
**Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
|
343 |
+
**User:** {os.environ.get('USER', 'gauravchand')}
|
344 |
""")
|
345 |
|
346 |
+
# Check for HF_TOKEN
|
347 |
+
if not os.environ.get('HF_TOKEN'):
|
348 |
+
st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
|
349 |
+
st.stop()
|
350 |
+
|
351 |
# Load models
|
352 |
with st.spinner("Loading models... This may take a few minutes."):
|
353 |
try:
|
|
|
407 |
key="translation_result"
|
408 |
)
|
409 |
|
410 |
+
# Download option
|
411 |
+
st.markdown("### Download Option")
|
412 |
+
st.download_button(
|
413 |
+
label="Download as DOCX",
|
414 |
+
data=DocumentExporter.save_as_docx(final_text),
|
415 |
+
file_name="translated_document.docx",
|
416 |
+
mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
|
417 |
+
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
418 |
|
419 |
status_text.text("Translation completed successfully!")
|
420 |
progress_bar.progress(100)
|