gauravchand11 committed on
Commit
c124a1a
·
verified ·
1 Parent(s): 1337d1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +114 -394
app.py CHANGED
@@ -1,426 +1,146 @@
 
1
  import streamlit as st
2
- import PyPDF2
3
  import docx
4
- import io
5
- from transformers import AutoModelForCausalLM, AutoTokenizer, AutoModelForSeq2SeqLM, MT5ForConditionalGeneration
6
- import torch
7
- from pathlib import Path
8
- from typing import Union, Tuple, List, Dict
9
  import os
10
- import sys
11
- from datetime import datetime, timezone
12
- import warnings
13
  import re
14
 
15
- # Filter warnings
16
- warnings.filterwarnings('ignore', category=UserWarning)
 
 
 
 
 
17
 
18
- # Page config
19
- st.set_page_config(
20
- page_title="Enhanced Document Translation App",
21
- page_icon="🌐",
22
- layout="wide"
23
- )
24
 
25
- # Constants and Configurations
26
- CONFIG = {
27
- "MAX_BATCH_LENGTH": 512,
28
- "MIN_BATCH_LENGTH": 50,
29
- "TRANSLATION_TEMPERATURE": 0.7,
30
- "CONTEXT_TEMPERATURE": 0.3,
31
- "NUM_BEAMS": 5,
32
- "SUPPORTED_LANGUAGES": {
33
- 'English': 'eng_Latn',
34
- 'Hindi': 'hin_Deva',
35
- 'Marathi': 'mar_Deva'
36
- },
37
- "MT5_LANG_CODES": {
38
- 'eng_Latn': 'en',
39
- 'hin_Deva': 'hi',
40
- 'mar_Deva': 'mr'
41
- },
42
- "GRAMMAR_PROMPTS": {
43
- 'en': "Fix grammar and improve fluency: ",
44
- 'hi': "व्याकरण और प्रवाह सुधारें: ",
45
- 'mr': "व्याकरण आणि प्रवाह सुधारा: "
46
- }
47
- }
48
-
49
- class DocumentProcessor:
50
- """Handles document processing and text extraction"""
51
-
52
- @staticmethod
53
- def extract_text_from_file(uploaded_file) -> str:
54
- file_extension = Path(uploaded_file.name).suffix.lower()
55
-
56
- extractors = {
57
- '.pdf': DocumentProcessor._extract_from_pdf,
58
- '.docx': DocumentProcessor._extract_from_docx,
59
- '.txt': lambda f: f.getvalue().decode('utf-8')
60
- }
61
-
62
- if file_extension not in extractors:
63
- raise ValueError(f"Unsupported file format: {file_extension}")
64
-
65
- return extractors[file_extension](uploaded_file)
66
-
67
- @staticmethod
68
- def _extract_from_pdf(file) -> str:
69
- pdf_reader = PyPDF2.PdfReader(file)
70
- return "\n".join(page.extract_text() for page in pdf_reader.pages).strip()
71
-
72
- @staticmethod
73
- def _extract_from_docx(file) -> str:
74
  doc = docx.Document(file)
75
- return "\n".join(paragraph.text for paragraph in doc.paragraphs).strip()
76
-
77
- class TextBatcher:
78
- """Handles text batching with improved sentence boundary detection"""
79
 
80
- @staticmethod
81
- def batch_process_text(text: str, max_length: int = CONFIG["MAX_BATCH_LENGTH"]) -> List[str]:
82
- sentences = TextBatcher._split_into_sentences(text)
83
- batches = []
84
- current_batch = []
85
- current_length = 0
86
-
87
- for sentence in sentences:
88
- sentence_length = len(sentence)
89
-
90
- if current_length + sentence_length > max_length:
91
- if current_batch:
92
- batches.append(" ".join(current_batch))
93
- current_batch = [sentence]
94
- current_length = sentence_length
95
- else:
96
- current_batch.append(sentence)
97
- current_length += sentence_length
98
-
99
- if current_batch:
100
- batches.append(" ".join(current_batch))
101
-
102
- return batches
103
 
104
- @staticmethod
105
- def _split_into_sentences(text: str) -> List[str]:
106
- """Split text into sentences with improved boundary detection"""
107
- delimiters = ['. ', '! ', '? ', '।', '॥', '\n']
108
- sentences = []
109
- current = text
110
-
111
- for delimiter in delimiters:
112
- parts = current.split(delimiter)
113
- current = parts[0]
114
- for part in parts[1:]:
115
- if len(current.strip()) > 0:
116
- sentences.append(current.strip() + delimiter.strip())
117
- current = part
118
-
119
- if len(current.strip()) > 0:
120
- sentences.append(current.strip())
121
-
122
- return sentences
123
 
124
- class ModelManager:
125
- """Manages loading and caching of AI models"""
126
-
127
- @st.cache_resource
128
- def load_models():
129
- try:
130
- device = "cuda" if torch.cuda.is_available() else "cpu"
131
-
132
- models = {
133
- "gemma": ModelManager._load_gemma_model(),
134
- "nllb": ModelManager._load_nllb_model(),
135
- "mt5": ModelManager._load_mt5_model()
136
- }
137
-
138
- if not torch.cuda.is_available():
139
- for model_tuple in models.values():
140
- model_tuple[1].to(device)
141
-
142
- return models
143
-
144
- except Exception as e:
145
- st.error(f"Error loading models: {str(e)}")
146
- st.error(f"Python version: {sys.version}")
147
- st.error(f"PyTorch version: {torch.__version__}")
148
- raise e
149
-
150
- @staticmethod
151
- def _load_gemma_model():
152
- tokenizer = AutoTokenizer.from_pretrained(
153
- "google/gemma-2b",
154
- token=os.environ.get('HF_TOKEN'),
155
- trust_remote_code=True
156
- )
157
- model = AutoModelForCausalLM.from_pretrained(
158
- "google/gemma-2b",
159
- token=os.environ.get('HF_TOKEN'),
160
- torch_dtype=torch.float16,
161
- device_map="auto" if torch.cuda.is_available() else None,
162
- trust_remote_code=True
163
- )
164
- return (tokenizer, model)
165
-
166
- @staticmethod
167
- def _load_nllb_model():
168
- tokenizer = AutoTokenizer.from_pretrained(
169
- "facebook/nllb-200-distilled-600M",
170
- token=os.environ.get('HF_TOKEN'),
171
- use_fast=False,
172
- trust_remote_code=True
173
- )
174
- model = AutoModelForSeq2SeqLM.from_pretrained(
175
- "facebook/nllb-200-distilled-600M",
176
- token=os.environ.get('HF_TOKEN'),
177
- torch_dtype=torch.float16,
178
- device_map="auto" if torch.cuda.is_available() else None,
179
- trust_remote_code=True
180
- )
181
- return (tokenizer, model)
182
-
183
- @staticmethod
184
- def _load_mt5_model():
185
- tokenizer = AutoTokenizer.from_pretrained(
186
- "google/mt5-base",
187
- token=os.environ.get('HF_TOKEN'),
188
- trust_remote_code=True
189
- )
190
- model = MT5ForConditionalGeneration.from_pretrained(
191
- "google/mt5-base",
192
- token=os.environ.get('HF_TOKEN'),
193
- torch_dtype=torch.float16,
194
- device_map="auto" if torch.cuda.is_available() else None,
195
- trust_remote_code=True
196
- )
197
- return (tokenizer, model)
198
 
199
- class TranslationPipeline:
200
- """Manages the translation pipeline with context understanding"""
201
-
202
- def __init__(self, models: Dict):
203
- self.models = models
204
-
205
- @torch.no_grad()
206
- def process_text(self, text: str, source_lang: str, target_lang: str) -> str:
207
- batches = TextBatcher.batch_process_text(text)
208
- final_results = []
209
-
210
- for batch in batches:
211
- # Step 1: Context Understanding
212
- context = self._understand_context(batch)
213
-
214
- # Step 2: Context-aware Translation
215
- translated = self._translate_with_context(
216
- context,
217
- source_lang,
218
- target_lang
219
- )
220
-
221
- # Step 3: Grammar Correction
222
- corrected = self._correct_grammar(
223
- translated,
224
- target_lang
225
  )
226
-
227
- final_results.append(corrected)
228
-
229
- # Clean up the final text
230
- final_text = " ".join(final_results)
231
- return self._clean_text(final_text)
232
 
233
- def _understand_context(self, text: str) -> str:
234
- tokenizer, model = self.models["gemma"]
235
-
236
- prompt = f"""Analyze and provide context for translation:
237
- Text: {text}
238
- Key points to consider:
239
- - Main topic and subject matter
240
- - Cultural context and nuances
241
- - Technical terminology if any
242
- - Tone and style of writing
243
 
244
- Provide a clear and concise interpretation that maintains:
245
- 1. Original meaning
246
- 2. Cultural context
247
- 3. Technical accuracy
248
- 4. Tone and style"""
249
-
250
- inputs = tokenizer(prompt, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
251
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
252
-
253
- outputs = model.generate(
254
- **inputs,
255
- max_length=CONFIG["MAX_BATCH_LENGTH"],
256
- do_sample=True,
257
- temperature=CONFIG["CONTEXT_TEMPERATURE"],
258
- pad_token_id=tokenizer.eos_token_id,
259
- num_return_sequences=1
260
- )
261
-
262
- context = tokenizer.decode(outputs[0], skip_special_tokens=True)
263
- return context.replace(prompt, "").strip()
264
-
265
- def _translate_with_context(self, text: str, source_lang: str, target_lang: str) -> str:
266
- tokenizer, model = self.models["nllb"]
267
-
268
- target_lang_token = f"___{target_lang}___"
269
- inputs = tokenizer(text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
270
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
271
-
272
- target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
273
-
274
- outputs = model.generate(
275
- **inputs,
276
- forced_bos_token_id=target_lang_id,
277
- max_length=CONFIG["MAX_BATCH_LENGTH"],
278
- do_sample=True,
279
- temperature=CONFIG["TRANSLATION_TEMPERATURE"],
280
- num_beams=CONFIG["NUM_BEAMS"],
281
- num_return_sequences=1,
282
- length_penalty=1.0,
283
- repetition_penalty=1.2
284
- )
285
-
286
- return tokenizer.decode(outputs[0], skip_special_tokens=True)
287
-
288
- def _correct_grammar(self, text: str, target_lang: str) -> str:
289
- tokenizer, model = self.models["mt5"]
290
- lang_code = CONFIG["MT5_LANG_CODES"][target_lang]
291
- prompt = CONFIG["GRAMMAR_PROMPTS"][lang_code]
292
-
293
- input_text = f"{prompt}{text}"
294
- inputs = tokenizer(input_text, return_tensors="pt", max_length=CONFIG["MAX_BATCH_LENGTH"], truncation=True)
295
- inputs = {k: v.to(model.device) for k, v in inputs.items()}
296
-
297
- outputs = model.generate(
298
- **inputs,
299
- max_length=CONFIG["MAX_BATCH_LENGTH"],
300
- num_beams=CONFIG["NUM_BEAMS"],
301
- length_penalty=1.0,
302
- early_stopping=True,
303
- no_repeat_ngram_size=2,
304
- do_sample=False
305
- )
306
-
307
- corrected = tokenizer.decode(outputs[0], skip_special_tokens=True)
308
- return self._clean_text(corrected.replace(prompt, "").strip())
309
-
310
- def _clean_text(self, text: str) -> str:
311
- """Clean up the text by removing special tokens and fixing formatting"""
312
- # Remove MT5 special tokens
313
- text = re.sub(r'<extra_id_\d+>', '', text)
314
-
315
- # Fix multiple spaces
316
- text = re.sub(r'\s+', ' ', text)
317
-
318
- # Fix punctuation spacing
319
- text = re.sub(r'\s+([.,!?।॥])', r'\1', text)
320
-
321
- return text.strip()
322
 
323
- class DocumentExporter:
324
- """Handles document export operations"""
325
-
326
- @staticmethod
327
- def save_as_docx(text: str) -> io.BytesIO:
328
- doc = docx.Document()
329
- doc.add_paragraph(text)
330
-
331
- buffer = io.BytesIO()
332
- doc.save(buffer)
333
- buffer.seek(0)
334
- return buffer
 
 
 
 
 
 
 
 
 
335
 
 
336
  def main():
337
- st.title("🌐 Enhanced Document Translation App")
338
-
339
- # Display system info
340
- st.sidebar.markdown(f"""
341
- ### System Information
342
- **Current UTC Time:** {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}
343
- **User:** {os.environ.get('USER', 'gauravchand')}
344
- """)
345
 
346
- # Check for HF_TOKEN
347
- if not os.environ.get('HF_TOKEN'):
348
- st.error("HF_TOKEN not found in environment variables. Please add it in the Spaces settings.")
349
- st.stop()
350
 
351
- # Load models
352
- with st.spinner("Loading models... This may take a few minutes."):
353
- try:
354
- models = ModelManager.load_models()
355
- pipeline = TranslationPipeline(models)
356
- except Exception as e:
357
- st.error(f"Error initializing translation pipeline: {str(e)}")
358
- return
359
-
360
- # File upload
361
- uploaded_file = st.file_uploader(
362
- "Upload your document (PDF, DOCX, or TXT)",
363
- type=['pdf', 'docx', 'txt']
364
- )
365
 
366
  # Language selection
367
  col1, col2 = st.columns(2)
368
  with col1:
369
- source_language = st.selectbox(
370
- "Source Language",
371
- options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
372
- index=0
373
- )
374
-
375
  with col2:
376
- target_language = st.selectbox(
377
- "Target Language",
378
- options=list(CONFIG["SUPPORTED_LANGUAGES"].keys()),
379
- index=1
380
- )
381
 
382
- if uploaded_file and st.button("Translate", type="primary"):
383
- try:
384
- progress_bar = st.progress(0)
385
- status_text = st.empty()
386
-
387
- # Process document
388
- status_text.text("Extracting text from document...")
389
- text = DocumentProcessor.extract_text_from_file(uploaded_file)
390
- progress_bar.progress(20)
391
-
392
- # Perform translation
393
- status_text.text("Translating document with context understanding...")
394
- final_text = pipeline.process_text(
395
- text,
396
- CONFIG["SUPPORTED_LANGUAGES"][source_language],
397
- CONFIG["SUPPORTED_LANGUAGES"][target_language]
398
- )
399
- progress_bar.progress(90)
400
 
401
  # Display result
402
- st.markdown("### Translation Result")
403
- st.text_area(
404
- label="Translated Text",
405
- value=final_text,
406
- height=200,
407
- key="translation_result"
408
- )
409
-
410
- # Download option
411
- st.markdown("### Download Option")
412
- st.download_button(
413
- label="Download as DOCX",
414
- data=DocumentExporter.save_as_docx(final_text),
415
- file_name="translated_document.docx",
416
- mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
417
- )
418
-
419
- status_text.text("Translation completed successfully!")
420
- progress_bar.progress(100)
421
 
422
- except Exception as e:
423
- st.error(f"An error occurred: {str(e)}")
 
 
 
 
 
 
424
 
425
  if __name__ == "__main__":
426
  main()
 
1
+ from transformers import AutoTokenizer, AutoModelForSeq2SeqLM
2
  import streamlit as st
3
+ from PyPDF2 import PdfReader
4
  import docx
 
 
 
 
 
5
  import os
 
 
 
6
  import re
7
 
8
+ # Load NLLB model and tokenizer
9
@st.cache_resource
def load_translation_model():
    """Load the NLLB-200 distilled 600M checkpoint.

    Returns:
        A ``(tokenizer, model)`` tuple built with ``AutoTokenizer`` and
        ``AutoModelForSeq2SeqLM`` from the same checkpoint name.
    """
    checkpoint = "facebook/nllb-200-distilled-600M"
    # Tuple elements are evaluated left to right: tokenizer first, then model.
    return (
        AutoTokenizer.from_pretrained(checkpoint),
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
    )
15
 
16
+ # Initialize model
17
@st.cache_resource
def initialize_models():
    """Build the model registry consumed by the translation functions.

    Returns:
        A dict with the single key ``"nllb"`` mapped to the
        ``(tokenizer, model)`` pair from ``load_translation_model()``.
    """
    nllb_tokenizer, nllb_model = load_translation_model()
    registry = {}
    registry["nllb"] = (nllb_tokenizer, nllb_model)
    return registry
 
21
 
22
+ # Function to extract text from different file types
23
def extract_text(file):
    """Extract plain text from an uploaded PDF, DOCX, or TXT file.

    Args:
        file: A binary file-like object with a ``name`` attribute (for
            example a Streamlit upload). The extension of ``name`` selects
            the extractor, case-insensitively.

    Returns:
        The extracted text. PDF pages and DOCX paragraphs are each followed
        by a newline; TXT content is returned as decoded.

    Raises:
        ValueError: If the extension is not ``.pdf``, ``.docx`` or ``.txt``.
        UnicodeDecodeError: If a ``.txt`` file is not valid UTF-8.
    """
    ext = os.path.splitext(file.name)[1].lower()

    # The same upload object can be handed in again on a later run with its
    # position already at EOF; rewind so a second extraction is not empty.
    rewind = getattr(file, "seek", None)
    if callable(rewind):
        rewind(0)

    if ext == ".pdf":
        reader = PdfReader(file)
        # A page without a text layer may give None/empty; treat it as "".
        return "".join((page.extract_text() or "") + "\n" for page in reader.pages)

    if ext == ".docx":
        doc = docx.Document(file)
        return "".join(para.text + "\n" for para in doc.paragraphs)

    if ext == ".txt":
        # utf-8-sig decodes plain UTF-8 identically and drops a leading BOM.
        return file.read().decode("utf-8-sig")

    raise ValueError("Unsupported file format. Please upload PDF, DOCX, or TXT files.")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
45
 
46
+ # Translation function
47
def translate_text(text, src_lang, tgt_lang, models):
    """Translate ``text`` line by line with the NLLB model.

    Args:
        text: Source text; it is split on ``"\\n"`` and each non-blank line
            is translated on its own.
        src_lang: Short source language code: ``"en"``, ``"hi"`` or ``"mr"``.
        tgt_lang: Short target language code, same set as ``src_lang``.
        models: Mapping whose ``"nllb"`` entry is a ``(tokenizer, model)`` pair.

    Returns:
        ``text`` unchanged when both languages are equal; the string
        ``"Error: Unsupported language combination"`` when either code is
        unknown; otherwise the translated lines, each terminated by a
        newline (blank input lines are dropped).
    """
    if src_lang == tgt_lang:
        return text

    # Language codes for NLLB
    lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}

    if src_lang not in lang_map or tgt_lang not in lang_map:
        return "Error: Unsupported language combination"

    tokenizer, model = models["nllb"]

    # Tell the tokenizer which language the input is in; without this the
    # source is tagged with the tokenizer's default language.
    # NOTE(review): this mutates the shared tokenizer object - confirm that is
    # acceptable if several sessions translate at the same time.
    tokenizer.src_lang = lang_map[src_lang]

    # lang_code_to_id is gone from recent transformers releases; the language
    # code is an ordinary vocabulary token, so look it up as one.
    forced_bos_token_id = tokenizer.convert_tokens_to_ids(lang_map[tgt_lang])

    # preprocess_idioms is not defined in this module. Use it when the module
    # provides it, instead of failing every translation with NameError.
    idiom_hook = globals().get("preprocess_idioms")
    if callable(idiom_hook):
        text = idiom_hook(text, src_lang, tgt_lang)

    translated_lines = []
    for sentence in text.split("\n"):
        if not sentence.strip():
            continue
        # Input beyond 512 tokens is truncated by the tokenizer.
        inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=512)
        generated = model.generate(
            **inputs,
            forced_bos_token_id=forced_bos_token_id,
            max_length=512,
        )
        translated_lines.append(tokenizer.decode(generated[0], skip_special_tokens=True))

    return "".join(line + "\n" for line in translated_lines)
 
 
 
 
 
 
 
 
 
81
 
82
+ # Function to save text as a file
83
+ def save_text_to_file(text, original_filename, prefix="translated"):
84
+ output_filename = f"{prefix}_{os.path.basename(original_filename)}.txt"
85
+ with open(output_filename, "w", encoding="utf-8") as f:
86
+ f.write(text)
87
+ return output_filename
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
88
 
89
+ # Main processing function
90
def process_document(file, source_lang, target_lang, models):
    """Extract, translate and persist one uploaded document.

    Args:
        file: Uploaded file object; ``file.name`` names the output file.
        source_lang: Source language code passed to ``translate_text``.
        target_lang: Target language code passed to ``translate_text``.
        models: Model registry passed to ``translate_text``.

    Returns:
        ``(output_file, text)`` where ``text`` is the translation or a
        message starting with ``"Error:"``, and ``output_file`` is the name
        of the file it was written to (prefixed ``error`` on failure).
    """
    try:
        outcome = translate_text(extract_text(file), source_lang, target_lang, models)

        # translate_text reports unsupported languages as an "Error:" string
        # rather than raising, so route that result to an error file too.
        if outcome.startswith("Error:"):
            return save_text_to_file(outcome, file.name, prefix="error"), outcome
        return save_text_to_file(outcome, file.name), outcome
    except Exception as exc:
        # Any failure is turned into text and saved, so the caller always
        # receives a file name and a message.
        failure = f"Error: {str(exc)}"
        return save_text_to_file(failure, file.name, prefix="error"), failure
110
 
111
# Streamlit interface
def main():
    """Render the translator page and run a translation on request.

    Shows the upload widget and language selectors. Once a file is uploaded
    and "Translate" is pressed, the document is processed; the result is
    shown and offered for download. Results that start with "Error:" are
    reported as errors instead of being presented as a translation.
    """
    st.title("Document Translator (NLLB-200)")
    st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages (English, Hindi, Marathi).")

    # Initialize models
    models = initialize_models()

    # File uploader
    uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])

    # Language selection
    col1, col2 = st.columns(2)
    with col1:
        source_lang = st.selectbox("Source Language", ["en", "hi", "mr"], index=0)
    with col2:
        target_lang = st.selectbox("Target Language", ["en", "hi", "mr"], index=1)

    # Same short-circuit as before: the button only exists once a file is uploaded.
    if uploaded_file is None or not st.button("Translate"):
        return

    with st.spinner("Translating..."):
        output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)

    # process_document signals failure through an "Error:" prefix.
    failed = result_text.startswith("Error:")

    # Display result
    if failed:
        st.error(result_text)
    else:
        st.text_area("Translated Text", result_text, height=300)

    # Provide download button; read the bytes up front so no file handle
    # has to stay open while the widget is built.
    with open(output_file, "rb") as saved:
        payload = saved.read()
    st.download_button(
        label="Download Error Report" if failed else "Download Translated Document",
        data=payload,
        file_name=os.path.basename(output_file),
        mime="text/plain",
    )
144
 
145
  if __name__ == "__main__":
146
  main()