gauravchand11 committed on
Commit
5a89d4a
·
verified ·
1 Parent(s): f88b938

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +82 -59
app.py CHANGED
@@ -10,11 +10,26 @@ from typing import Union, Tuple
10
  import os
11
  import sys
12
  from datetime import datetime, timezone
 
13
 
14
- # Display current information in sidebar
15
- st.sidebar.text(f"Current Date and Time (UTC):")
16
- st.sidebar.text(datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S'))
17
- st.sidebar.text(f"Current User's Login: {os.environ.get('USER', 'gauravchand')}")
 
 
 
 
 
 
 
 
 
 
 
 
 
 
18
 
19
  # Get Hugging Face token from environment variables
20
  HF_TOKEN = os.environ.get('HF_TOKEN')
@@ -61,7 +76,8 @@ def load_models():
61
  nllb_tokenizer = AutoTokenizer.from_pretrained(
62
  "facebook/nllb-200-distilled-600M",
63
  token=HF_TOKEN,
64
- trust_remote_code=True
 
65
  )
66
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
67
  "facebook/nllb-200-distilled-600M",
@@ -75,7 +91,9 @@ def load_models():
75
  mt5_tokenizer = AutoTokenizer.from_pretrained(
76
  "google/mt5-small",
77
  token=HF_TOKEN,
78
- trust_remote_code=True
 
 
79
  )
80
  mt5_model = MT5ForConditionalGeneration.from_pretrained(
81
  "google/mt5-small",
@@ -155,7 +173,6 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
155
  """Use Gemma model to interpret context and understand regional nuances."""
156
  tokenizer, model = gemma_tuple
157
 
158
- # Split text into batches
159
  batches = batch_process_text(text)
160
  interpreted_batches = []
161
 
@@ -186,23 +203,21 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
186
  """Translate text using NLLB model."""
187
  tokenizer, model = nllb_tuple
188
 
189
- # Split text into batches
190
  batches = batch_process_text(text)
191
  translated_batches = []
192
 
193
  for batch in batches:
194
- # Prepare the input text with source language token
195
- inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
 
196
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
197
 
198
- # Get target language token ID
199
- target_lang_token = f"___{target_lang}___"
200
- target_lang_id = tokenizer.convert_tokens_to_ids(target_lang_token)
201
 
202
- # Generate translation
203
  outputs = model.generate(
204
  **inputs,
205
- forced_bos_token_id=target_lang_id,
206
  max_length=512,
207
  do_sample=True,
208
  temperature=0.7,
@@ -217,21 +232,16 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
217
 
218
  @torch.no_grad()
219
  def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
220
- """
221
- Correct grammar using MT5 model for all supported languages.
222
- Uses a text-to-text approach with language-specific prompts.
223
- """
224
  tokenizer, model = mt5_tuple
225
  lang_code = MT5_LANG_CODES[target_lang]
226
 
227
- # Language-specific prompts for grammar correction
228
  prompts = {
229
  'en': "grammar: ",
230
  'hi': "व्याकरण सुधार: ",
231
  'mr': "व्याकरण सुधारणा: "
232
  }
233
 
234
- # Split text into batches
235
  batches = batch_process_text(text)
236
  corrected_batches = []
237
 
@@ -251,8 +261,6 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
251
  )
252
 
253
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
254
-
255
- # Clean up any artifacts from the model output
256
  for prefix in prompts.values():
257
  corrected_text = corrected_text.replace(prefix, "")
258
  corrected_text = corrected_text.strip()
@@ -273,7 +281,7 @@ def save_as_docx(text: str) -> io.BytesIO:
273
  return docx_buffer
274
 
275
  def main():
276
- st.title("Document Translation App")
277
 
278
  # Load models
279
  with st.spinner("Loading models... This may take a few minutes."):
@@ -306,40 +314,52 @@ def main():
306
  index=1
307
  )
308
 
309
- if uploaded_file and st.button("Translate"):
310
  try:
311
- with st.spinner("Processing document..."):
312
- # Extract text
313
- text = extract_text_from_file(uploaded_file)
314
-
315
- # Interpret context
316
- with st.spinner("Interpreting context..."):
317
- interpreted_text = interpret_context(text, gemma_tuple)
318
-
319
- # Translate
320
- with st.spinner("Translating..."):
321
- translated_text = translate_text(
322
- interpreted_text,
323
- SUPPORTED_LANGUAGES[source_language],
324
- SUPPORTED_LANGUAGES[target_language],
325
- nllb_tuple
326
- )
327
-
328
- # Grammar correction
329
- with st.spinner("Correcting grammar..."):
330
- corrected_text = correct_grammar(
331
- translated_text,
332
- SUPPORTED_LANGUAGES[target_language],
333
- mt5_tuple
334
- )
335
-
336
- # Display result
337
- st.subheader("Translation Result:")
338
- st.text_area("Translated Text:", value=corrected_text, height=150)
339
-
340
- # Download options
341
- st.subheader("Download Translation:")
342
-
 
 
 
 
 
 
 
 
 
 
 
 
343
  # Text file download
344
  text_buffer = io.BytesIO()
345
  text_buffer.write(corrected_text.encode())
@@ -351,7 +371,8 @@ def main():
351
  file_name="translated_document.txt",
352
  mime="text/plain"
353
  )
354
-
 
355
  # DOCX file download
356
  docx_buffer = save_as_docx(corrected_text)
357
  st.download_button(
@@ -360,7 +381,9 @@ def main():
360
  file_name="translated_document.docx",
361
  mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
362
  )
363
-
 
 
364
  except Exception as e:
365
  st.error(f"An error occurred: {str(e)}")
366
 
 
10
  import os
11
  import sys
12
  from datetime import datetime, timezone
13
+ import warnings
14
 
15
+ # Filter out specific warnings
16
+ warnings.filterwarnings('ignore', category=UserWarning, module='transformers.convert_slow_tokenizer')
17
+ warnings.filterwarnings('ignore', category=UserWarning, module='transformers.tokenization_utils_base')
18
+
19
+ # Custom styling
20
+ st.set_page_config(
21
+ page_title="Document Translation App",
22
+ page_icon="🌐",
23
+ layout="wide"
24
+ )
25
+
26
+ # Display current information in sidebar with proper formatting
27
+ current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
28
+ st.sidebar.markdown("""
29
+ ### System Information
30
+ **Current UTC Time:** {}
31
+ **User:** {}
32
+ """.format(current_time, os.environ.get('USER', 'gauravchand')))
33
 
34
  # Get Hugging Face token from environment variables
35
  HF_TOKEN = os.environ.get('HF_TOKEN')
 
76
  nllb_tokenizer = AutoTokenizer.from_pretrained(
77
  "facebook/nllb-200-distilled-600M",
78
  token=HF_TOKEN,
79
+ trust_remote_code=True,
80
+ use_fast=False # Use slow tokenizer to avoid warnings
81
  )
82
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
83
  "facebook/nllb-200-distilled-600M",
 
91
  mt5_tokenizer = AutoTokenizer.from_pretrained(
92
  "google/mt5-small",
93
  token=HF_TOKEN,
94
+ trust_remote_code=True,
95
+ legacy=False, # Use new behavior
96
+ use_fast=False # Use slow tokenizer to avoid warnings
97
  )
98
  mt5_model = MT5ForConditionalGeneration.from_pretrained(
99
  "google/mt5-small",
 
173
  """Use Gemma model to interpret context and understand regional nuances."""
174
  tokenizer, model = gemma_tuple
175
 
 
176
  batches = batch_process_text(text)
177
  interpreted_batches = []
178
 
 
203
  """Translate text using NLLB model."""
204
  tokenizer, model = nllb_tuple
205
 
 
206
  batches = batch_process_text(text)
207
  translated_batches = []
208
 
209
  for batch in batches:
210
+ # Add source language token to input
211
+ batch_with_lang = f"{source_lang} {batch}"
212
+ inputs = tokenizer(batch_with_lang, return_tensors="pt", max_length=512, truncation=True)
213
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
214
 
215
+ # Add target language token
216
+ target_lang_token = tokenizer(target_lang, add_special_tokens=False)["input_ids"][0]
 
217
 
 
218
  outputs = model.generate(
219
  **inputs,
220
+ forced_bos_token_id=target_lang_token,
221
  max_length=512,
222
  do_sample=True,
223
  temperature=0.7,
 
232
 
233
  @torch.no_grad()
234
  def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
235
+ """Correct grammar using MT5 model for all supported languages."""
 
 
 
236
  tokenizer, model = mt5_tuple
237
  lang_code = MT5_LANG_CODES[target_lang]
238
 
 
239
  prompts = {
240
  'en': "grammar: ",
241
  'hi': "व्याकरण सुधार: ",
242
  'mr': "व्याकरण सुधारणा: "
243
  }
244
 
 
245
  batches = batch_process_text(text)
246
  corrected_batches = []
247
 
 
261
  )
262
 
263
  corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
 
 
264
  for prefix in prompts.values():
265
  corrected_text = corrected_text.replace(prefix, "")
266
  corrected_text = corrected_text.strip()
 
281
  return docx_buffer
282
 
283
  def main():
284
+ st.title("🌐 Document Translation App")
285
 
286
  # Load models
287
  with st.spinner("Loading models... This may take a few minutes."):
 
314
  index=1
315
  )
316
 
317
+ if uploaded_file and st.button("Translate", type="primary"):
318
  try:
319
+ progress_bar = st.progress(0)
320
+
321
+ # Extract text
322
+ text = extract_text_from_file(uploaded_file)
323
+ progress_bar.progress(20)
324
+
325
+ # Interpret context
326
+ with st.spinner("Interpreting context..."):
327
+ interpreted_text = interpret_context(text, gemma_tuple)
328
+ progress_bar.progress(40)
329
+
330
+ # Translate
331
+ with st.spinner("Translating..."):
332
+ translated_text = translate_text(
333
+ interpreted_text,
334
+ SUPPORTED_LANGUAGES[source_language],
335
+ SUPPORTED_LANGUAGES[target_language],
336
+ nllb_tuple
337
+ )
338
+ progress_bar.progress(70)
339
+
340
+ # Grammar correction
341
+ with st.spinner("Correcting grammar..."):
342
+ corrected_text = correct_grammar(
343
+ translated_text,
344
+ SUPPORTED_LANGUAGES[target_language],
345
+ mt5_tuple
346
+ )
347
+ progress_bar.progress(90)
348
+
349
+ # Display result
350
+ st.markdown("### Translation Result")
351
+ st.text_area(
352
+ label="Translated Text",
353
+ value=corrected_text,
354
+ height=200,
355
+ key="translation_result"
356
+ )
357
+
358
+ # Download options
359
+ st.markdown("### Download Options")
360
+ col1, col2 = st.columns(2)
361
+
362
+ with col1:
363
  # Text file download
364
  text_buffer = io.BytesIO()
365
  text_buffer.write(corrected_text.encode())
 
371
  file_name="translated_document.txt",
372
  mime="text/plain"
373
  )
374
+
375
+ with col2:
376
  # DOCX file download
377
  docx_buffer = save_as_docx(corrected_text)
378
  st.download_button(
 
381
  file_name="translated_document.docx",
382
  mime="application/vnd.openxmlformats-officedocument.wordprocessingml.document"
383
  )
384
+
385
+ progress_bar.progress(100)
386
+
387
  except Exception as e:
388
  st.error(f"An error occurred: {str(e)}")
389