gauravchand11 committed on
Commit
1337d1b
·
verified ·
1 Parent(s): 5e3207d

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +32 -48
app.py CHANGED
@@ -5,13 +5,12 @@ import io
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
6
  import torch
7
  from pathlib import Path
8
- import tempfile
9
  from typing import Union, Tuple, List, Dict
10
  import os
11
  import sys
12
  from datetime import datetime, timezone
13
  import warnings
14
- import json
15
 
16
  # Filter warnings
17
  warnings.filterwarnings('ignore', category=UserWarning)
@@ -105,7 +104,6 @@ class TextBatcher:
105
  @staticmethod
106
  def _split_into_sentences(text: str) -> List[str]:
107
  """Split text into sentences with improved boundary detection"""
108
- # Basic sentence boundary detection
109
  delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
110
  sentences = []
111
  current = text
@@ -131,14 +129,12 @@ class ModelManager:
131
  try:
132
  device = "cuda" if torch.cuda.is_available() else "cpu"
133
 
134
- # Load models with improved error handling
135
  models = {
136
  "gemma": ModelManager._load_gemma_model(),
137
  "nllb": ModelManager._load_nllb_model(),
138
  "mt5": ModelManager._load_mt5_model()
139
  }
140
 
141
- # Move models to appropriate device
142
  if not torch.cuda.is_available():
143
  for model_tuple in models.values():
144
  model_tuple[1].to(device)
@@ -208,7 +204,6 @@ class TranslationPipeline:
208
 
209
  @torch.no_grad()
210
  def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
211
- # Split text into manageable batches
212
  batches = TextBatcher.batch_process_text(text)
213
  final_results = []
214
 
@@ -231,10 +226,11 @@ class TranslationPipeline:
231
 
232
  final_results.append(corrected)
233
 
234
- return " ".join(final_results)
 
 
235
 
236
  def _understand_context(self, text: str) -> str:
237
- """Enhanced context understanding using Gemma model"""
238
  tokenizer, model = self.models["gemma"]
239
 
240
  prompt = f"""Analyze and provide context for translation:
@@ -267,12 +263,9 @@ Provide a clear and concise interpretation that maintains:
267
  return context.replace(prompt, "").strip()
268
 
269
  def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
270
- """Enhanced translation using NLLB model with context awareness"""
271
  tokenizer, model = self.models["nllb"]
272
 
273
- source_lang_token = f"___{source_lang}___"
274
  target_lang_token = f"___{target_lang}___"
275
-
276
  inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
277
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
278
 
@@ -293,7 +286,6 @@ Provide a clear and concise interpretation that maintains:
293
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
294
 
295
  def _correct_grammar(self, text: str, target_lang: str) -> str:
296
- """Enhanced grammar correction using MT5 model"""
297
  tokenizer, model = self.models["mt5"]
298
  lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
299
  prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
@@ -313,9 +305,20 @@ Provide a clear and concise interpretation that maintains:
313
  )
314
 
315
  corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
316
- for prefix in CONFIG["GRAMMAR_PROMPTS"].values():
317
- corrected = corrected.replace(prefix, "")
318
- return corrected.strip()
 
 
 
 
 
 
 
 
 
 
 
319
 
320
  class DocumentExporter:
321
  """Handles document export operations"""
@@ -328,31 +331,23 @@ class DocumentExporter:
328
  buffer = io.BytesIO()
329
  doc.save(buffer)
330
  buffer.seek(0)
331
-
332
- return buffer
333
-
334
- @staticmethod
335
- def save_as_text(text: str) -> io.BytesIO:
336
- buffer = io.BytesIO()
337
- buffer.write(text.encode())
338
- buffer.seek(0)
339
  return buffer
340
 
341
  def main():
342
  st.title("🌐 Enhanced Document Translation App")
343
 
344
- # Check for HF_TOKEN
345
- if not os.environ.get('HF_TOKEN'):
346
- st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
347
- st.stop()
348
-
349
  # Display system info
350
  st.sidebar.markdown(f"""
351
  ### System Information
352
  **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
353
- **User:** {os.environ.get('USER', 'unknown')}
354
  """)
355
 
 
 
 
 
 
356
  # Load models
357
  with st.spinner("Loading models... This may take a few minutes."):
358
  try:
@@ -412,25 +407,14 @@ def main():
412
  key="translation_result"
413
  )
414
 
415
- # Download options
416
- st.markdown("### Download Options")
417
- col1, col2 = st.columns(2)
418
-
419
- with col1:
420
- st.download_button(
421
- label="Download as TXT",
422
- data=DocumentExporter.save_as_text(final_text),
423
- file_name="translated_document.txt",
424
- mime="text/plain"
425
- )
426
-
427
- with col2:
428
- st.download_button(
429
- label="Download as DOCX",
430
- data=DocumentExporter.save_as_docx(final_text),
431
- file_name="translated_document.docx",
432
- mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
433
- )
434
 
435
  status_text.text("Translation completed successfully!")
436
  progress_bar.progress(100)
 
5
  from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
6
  import torch
7
  from pathlib import Path
 
8
  from typing import Union, Tuple, List, Dict
9
  import os
10
  import sys
11
  from datetime import datetime, timezone
12
  import warnings
13
+ import re
14
 
15
  # Filter warnings
16
  warnings.filterwarnings('ignore', category=UserWarning)
 
104
  @staticmethod
105
  def _split_into_sentences(text: str) -> List[str]:
106
  """Split text into sentences with improved boundary detection"""
 
107
  delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
108
  sentences = []
109
  current = text
 
129
  try:
130
  device = "cuda" if torch.cuda.is_available() else "cpu"
131
 
 
132
  models = {
133
  "gemma": ModelManager._load_gemma_model(),
134
  "nllb": ModelManager._load_nllb_model(),
135
  "mt5": ModelManager._load_mt5_model()
136
  }
137
 
 
138
  if not torch.cuda.is_available():
139
  for model_tuple in models.values():
140
  model_tuple[1].to(device)
 
204
 
205
  @torch.no_grad()
206
  def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
 
207
  batches = TextBatcher.batch_process_text(text)
208
  final_results = []
209
 
 
226
 
227
  final_results.append(corrected)
228
 
229
+ # Clean up the final text
230
+ final_text = " ".join(final_results)
231
+ return self._clean_text(final_text)
232
 
233
  def _understand_context(self, text: str) -> str:
 
234
  tokenizer, model = self.models["gemma"]
235
 
236
  prompt = f"""Analyze and provide context for translation:
 
263
  return context.replace(prompt, "").strip()
264
 
265
  def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
 
266
  tokenizer, model = self.models["nllb"]
267
 
 
268
  target_lang_token = f"___{target_lang}___"
 
269
  inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
270
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
271
 
 
286
  return tokenizer.decode(outputs[0], skip_special_tokens=True)
287
 
288
  def _correct_grammar(self, text: str, target_lang: str) -> str:
 
289
  tokenizer, model = self.models["mt5"]
290
  lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
291
  prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
 
305
  )
306
 
307
  corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
308
+ return self._clean_text(corrected.replace(prompt, "").strip())
309
+
310
+ def _clean_text(self, text: str) -> str:
311
+ """Clean up the text by removing special tokens and fixing formatting"""
312
+ # Remove MT5 special tokens
313
+ text = re.sub(r'<extra_id_\d+>', '', text)
314
+
315
+ # Fix multiple spaces
316
+ text = re.sub(r'\s+', ' ', text)
317
+
318
+ # Fix punctuation spacing
319
+ text = re.sub(r'\s+([.,!?।॥])', r'\1', text)
320
+
321
+ return text.strip()
322
 
323
  class DocumentExporter:
324
  """Handles document export operations"""
 
331
  buffer = io.BytesIO()
332
  doc.save(buffer)
333
  buffer.seek(0)
 
 
 
 
 
 
 
 
334
  return buffer
335
 
336
  def main():
337
  st.title("🌐 Enhanced Document Translation App")
338
 
 
 
 
 
 
339
  # Display system info
340
  st.sidebar.markdown(f"""
341
  ### System Information
342
  **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
343
+ **User:** {os.environ.get('USER', 'gauravchand')}
344
  """)
345
 
346
+ # Check for HF_TOKEN
347
+ if not os.environ.get('HF_TOKEN'):
348
+ st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
349
+ st.stop()
350
+
351
  # Load models
352
  with st.spinner("Loading models... This may take a few minutes."):
353
  try:
 
407
  key="translation_result"
408
  )
409
 
410
+ # Download option
411
+ st.markdown("### Download Option")
412
+ st.download_button(
413
+ label="Download as DOCX",
414
+ data=DocumentExporter.save_as_docx(final_text),
415
+ file_name="translated_document.docx",
416
+ mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
417
+ )
 
 
 
 
 
 
 
 
 
 
 
418
 
419
  status_text.text("Translation completed successfully!")
420
  progress_bar.progress(100)