gauravchand11 committed
Commit ed75acb · verified · 1 Parent(s): f4c8d2f

Update app.py

Files changed (1)
  1. app.py +107 -95
app.py CHANGED
@@ -12,33 +12,41 @@ import os
 import re
 import torch
 import numpy as np
+from datetime import datetime, timezone
 
 # Load models and tokenizers
 @st.cache_resource
 def load_models():
-    # BERT model for context understanding
-    context_tokenizer = BertTokenizer.from_pretrained('google-bert/bert-base-multilingual-cased')
-    context_model = BertModel.from_pretrained('google-bert/bert-base-multilingual-cased')
-
-    # NLLB model for translation
-    nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
-    nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
-
-    # GECToR model for grammar correction
-    grammar_tokenizer = AutoTokenizer.from_pretrained('gotutiyan/gector-bert-base-cased-5k')
-    grammar_model = AutoModelForTokenClassification.from_pretrained('gotutiyan/gector-bert-base-cased-5k')
-
-    return {
-        "context": (context_tokenizer, context_model),
-        "nllb": (nllb_tokenizer, nllb_model),
-        "grammar": (grammar_tokenizer, grammar_model)
-    }
+    try:
+        # BERT model for context understanding
+        context_tokenizer = BertTokenizer.from_pretrained('bert-base-multilingual-cased')
+        context_model = BertModel.from_pretrained('bert-base-multilingual-cased')
+
+        # NLLB model for translation
+        nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
+        nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")
+
+        # Grammar correction model
+        grammar_tokenizer = BertTokenizer.from_pretrained('bert-base-cased')
+        grammar_model = AutoModelForTokenClassification.from_pretrained(
+            'bert-base-cased',
+            num_labels=3  # Assuming 3 labels: keep, delete, replace
+        )
+
+        return {
+            "context": (context_tokenizer, context_model),
+            "nllb": (nllb_tokenizer, nllb_model),
+            "grammar": (grammar_tokenizer, grammar_model)
+        }
+    except Exception as e:
+        st.error(f"Error loading models: {str(e)}")
+        raise e
 
 def get_bert_embeddings(text, models):
     """Get contextual embeddings from BERT"""
     tokenizer, model = models["context"]
 
-    # Split text into smaller chunks if needed
+    # Split text into smaller chunks
     max_length = 512
     chunks = [text[i:i + max_length] for i in range(0, len(text), max_length)]
     contextual_embeddings = []
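
Editor's note on this hunk: the GECToR checkpoint (gotutiyan/gector-bert-base-cased-5k) is swapped for plain bert-base-cased with a fresh 3-label head. That head is randomly initialized rather than trained for grammar correction, so the keep/delete/replace labels are an assumption, as the commit's own comment says. A standalone smoke test of the three loads (editor's sketch, not part of the commit):

# Editor's sketch: load the three checkpoints named in the hunk above outside
# Streamlit. transformers will warn that the token-classification head on
# 'bert-base-cased' is newly initialized, i.e. untrained.
from transformers import (
    AutoModelForSeq2SeqLM,
    AutoModelForTokenClassification,
    AutoTokenizer,
    BertModel,
    BertTokenizer,
)

context_tokenizer = BertTokenizer.from_pretrained("bert-base-multilingual-cased")
context_model = BertModel.from_pretrained("bert-base-multilingual-cased")

nllb_tokenizer = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
nllb_model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

grammar_tokenizer = BertTokenizer.from_pretrained("bert-base-cased")
grammar_model = AutoModelForTokenClassification.from_pretrained(
    "bert-base-cased", num_labels=3
)

print(grammar_model.config.num_labels)  # 3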
@@ -47,7 +55,7 @@ def get_bert_embeddings(text, models):
         inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
         with torch.no_grad():
             outputs = model(**inputs)
-            embeddings = outputs.last_hidden_state.mean(dim=1)  # Average pooling
+            embeddings = outputs.last_hidden_state.mean(dim=1)
         contextual_embeddings.append(embeddings)
 
     # Combine embeddings from all chunks
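
Editor's note: the changed line averages BERT's hidden states over the token axis, one vector per chunk. A minimal shape check (sketch, not part of the commit):

# Editor's sketch: for bert-base models, last_hidden_state is
# (batch, seq_len, 768); mean(dim=1) collapses the token axis.
import torch

last_hidden_state = torch.randn(1, 12, 768)  # stand-in for a BERT output
embedding = last_hidden_state.mean(dim=1)    # average pooling over tokens
print(embedding.shape)                       # torch.Size([1, 768])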
@@ -55,7 +63,7 @@ def get_bert_embeddings(text, models):
     return combined_embedding
 
 def apply_grammar_correction(text, models):
-    """Apply grammar correction using GECToR"""
+    """Basic grammar correction using BERT"""
    tokenizer, model = models["grammar"]
 
    sentences = re.split('([.!?।]+)', text)
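
Editor's note: the capturing group in re.split keeps the sentence terminators, including the Devanagari danda '।', as separate list items, so punctuation survives the split. An illustration (sketch, not part of the commit):

# Editor's sketch: capturing groups in the pattern are returned by re.split.
import re

sample = "Hello there! How are you? यह एक वाक्य है।"
print(re.split('([.!?।]+)', sample))
# ['Hello there', '!', ' How are you', '?', ' यह एक वाक्य है', '।', '']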
@@ -63,22 +71,23 @@ def apply_grammar_correction(text, models):
 
     for sentence in sentences:
         if sentence.strip():
+            # Basic tokenization and prediction
             inputs = tokenizer(sentence, return_tensors="pt", padding=True, truncation=True, max_length=128)
             with torch.no_grad():
                 outputs = model(**inputs)
                 predictions = torch.argmax(outputs.logits, dim=2)
 
-            # Convert predictions to corrected text
             tokens = tokenizer.convert_ids_to_tokens(inputs['input_ids'][0])
             corrected_tokens = []
 
             for token, pred in zip(tokens, predictions[0]):
-                if pred == 0:  # Keep token as is
-                    corrected_tokens.append(token)
-                # Handle other prediction cases as needed
+                if pred == 0 or token in ['[CLS]', '[SEP]', '[PAD]']:
+                    if token not in ['[CLS]', '[SEP]', '[PAD]']:
+                        corrected_tokens.append(token)
 
             corrected_text = tokenizer.convert_tokens_to_string(corrected_tokens)
-            corrected_sentences.append(corrected_text)
+            if corrected_text.strip():
+                corrected_sentences.append(corrected_text)
 
     return " ".join(corrected_sentences)
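
Editor's note on this hunk: a token is now kept only when its predicted label is 0 and it is not a special token; any other label silently drops the token, and with the untrained head loaded above those labels are effectively random. The special-token filtering itself works as in this sketch (not part of the commit):

# Editor's sketch: BERT brackets every sentence with special tokens, which
# is why the hunk filters them before convert_tokens_to_string.
from transformers import BertTokenizer

tok = BertTokenizer.from_pretrained("bert-base-cased")
ids = tok("He go to school.", return_tensors="pt")["input_ids"][0]
tokens = tok.convert_ids_to_tokens(ids)
print(tokens)  # ['[CLS]', 'He', 'go', 'to', 'school', '.', '[SEP]']
kept = [t for t in tokens if t not in ('[CLS]', '[SEP]', '[PAD]')]
print(tok.convert_tokens_to_string(kept))  # He go to school .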
 
@@ -109,7 +118,6 @@ def translate_text(text, src_lang, tgt_lang, models):
     if src_lang == tgt_lang:
         return text
 
-    # Language codes for NLLB
     lang_map = {"en": "eng_Latn", "hi": "hin_Deva", "mr": "mar_Deva"}
 
     if src_lang not in lang_map or tgt_lang not in lang_map:
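
Editor's note: the FLORES-200 codes in lang_map are ordinary vocabulary items to the NLLB tokenizer, so a typo would silently resolve to the <unk> id. A quick sanity check (sketch, not part of the commit):

# Editor's sketch: verify the three language codes exist in the NLLB vocab.
from transformers import AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M")
for code in ("eng_Latn", "hin_Deva", "mar_Deva"):
    assert tok.convert_tokens_to_ids(code) != tok.unk_token_id, code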
@@ -118,62 +126,67 @@ def translate_text(text, src_lang, tgt_lang, models):
     tgt_lang_code = lang_map[tgt_lang]
     tokenizer, model = models["nllb"]
 
-    # Get contextual embeddings
-    context_embedding = get_bert_embeddings(text, models)
-
-    # Split into chunks for translation
-    chunks = []
-    current_chunk = ""
-
-    for sentence in re.split('([.!?।]+)', text):
-        if sentence.strip():
-            if len(current_chunk) + len(sentence) < 450:
-                current_chunk += sentence
-            else:
-                if current_chunk:
-                    chunks.append(current_chunk)
-                current_chunk = sentence
-
-    if current_chunk:
-        chunks.append(current_chunk)
-
-    translated_text = ""
-
-    for chunk in chunks:
-        if chunk.strip():
-            # Prepare input with context
-            inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
-
-            # Add context embedding to attention mask
-            attention_mask = inputs['attention_mask'].float()
-            attention_mask = attention_mask * (1 + 0.1 * context_embedding.norm())
-
-            # Get target language token ID
-            tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
-
-            # Generate translation
-            with torch.no_grad():
-                translated = model.generate(
-                    input_ids=inputs['input_ids'],
-                    attention_mask=attention_mask,
-                    forced_bos_token_id=tgt_lang_id,
-                    max_length=512,
-                    num_beams=5,
-                    length_penalty=1.0,
-                    no_repeat_ngram_size=3,
-                    do_sample=True,
-                    temperature=0.7
-                )
-            translated_chunk = tokenizer.decode(translated[0], skip_special_tokens=True)
-            translated_text += translated_chunk + " "
-
-    # Apply grammar correction
-    corrected_text = apply_grammar_correction(translated_text.strip(), models)
-
-    return corrected_text
+    try:
+        # Get contextual embeddings
+        context_embedding = get_bert_embeddings(text, models)
+
+        # Split into chunks for translation
+        chunks = []
+        current_chunk = ""
+
+        for sentence in re.split('([.!?।]+)', text):
+            if sentence.strip():
+                if len(current_chunk) + len(sentence) < 450:
+                    current_chunk += sentence
+                else:
+                    if current_chunk:
+                        chunks.append(current_chunk)
+                    current_chunk = sentence
+
+        if current_chunk:
+            chunks.append(current_chunk)
+
+        translated_text = ""
+
+        for chunk in chunks:
+            if chunk.strip():
+                inputs = tokenizer(chunk, return_tensors="pt", padding=True, truncation=True, max_length=512)
+
+                # Use context embedding to modify attention
+                attention_mask = inputs['attention_mask'].float()
+                context_weight = 0.1 * torch.sigmoid(context_embedding.mean())
+                attention_mask = attention_mask * (1 + context_weight)
+
+                # Get target language token ID
+                tgt_lang_id = tokenizer.convert_tokens_to_ids(tgt_lang_code)
+
+                with torch.no_grad():
+                    translated = model.generate(
+                        input_ids=inputs['input_ids'],
+                        attention_mask=attention_mask,
+                        forced_bos_token_id=tgt_lang_id,
+                        max_length=512,
+                        num_beams=5,
+                        length_penalty=1.0,
+                        no_repeat_ngram_size=3,
+                        do_sample=True,
+                        temperature=0.7
+                    )
+                translated_chunk = tokenizer.decode(translated[0], skip_special_tokens=True)
+                translated_text += translated_chunk + " "
+
+        # Apply basic grammar correction
+        corrected_text = apply_grammar_correction(translated_text.strip(), models)
+
+        return corrected_text
+
+    except Exception as e:
+        st.error(f"Translation error: {str(e)}")
+        return f"Error during translation: {str(e)}"
 
 def save_text_to_file(text, original_filename, prefix="translated"):
-    output_filename = f"{prefix}_{os.path.basename(original_filename)}.txt"
+    timestamp = datetime.now(timezone.utc).strftime("%Y%m%d_%H%M%S")
+    output_filename = f"{prefix}_{timestamp}_{os.path.basename(original_filename)}.txt"
     with open(output_filename, "w", encoding="utf-8") as f:
         f.write(text)
     return output_filename
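
Editor's notes on this hunk: multiplying attention_mask by (1 + context_weight) yields non-binary mask values, which recent transformers versions fold into an additive bias, so this "context weighting" is more likely to distort encoder attention than to inject document context; and do_sample=True combined with num_beams=5 selects beam-search sampling, making translations non-deterministic. The core NLLB call, isolated (sketch, not part of the commit; src_lang defaults to eng_Latn and should be set per input):

# Editor's sketch: plain NLLB generation with a forced target-language token.
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

tok = AutoTokenizer.from_pretrained("facebook/nllb-200-distilled-600M", src_lang="eng_Latn")
model = AutoModelForSeq2SeqLM.from_pretrained("facebook/nllb-200-distilled-600M")

inputs = tok("The weather is nice today.", return_tensors="pt")
with torch.no_grad():
    out = model.generate(
        **inputs,
        forced_bos_token_id=tok.convert_tokens_to_ids("hin_Deva"),  # target: Hindi
        max_length=512,
        num_beams=5,
    )
print(tok.decode(out[0], skip_special_tokens=True))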
@@ -183,15 +196,14 @@ def process_document(file, source_lang, target_lang, models):
         # Extract text from uploaded file
         text = extract_text(file)
 
-        # Log the input text for debugging
-        st.sidebar.write("Input text:", text[:500] + "...")
+        # Add debugging information
+        st.sidebar.write("Processing document...")
+        st.sidebar.write(f"Source language: {source_lang}")
+        st.sidebar.write(f"Target language: {target_lang}")
 
-        # Translate the text with context awareness and grammar correction
+        # Translate the text
         translated_text = translate_text(text, source_lang, target_lang, models)
 
-        # Log the output text for debugging
-        st.sidebar.write("Output text:", translated_text[:500] + "...")
-
         # Save the result
         if translated_text.startswith("Error:"):
             output_file = save_text_to_file(translated_text, file.name, prefix="error")
@@ -199,9 +211,10 @@ def process_document(file, source_lang, target_lang, models):
             output_file = save_text_to_file(translated_text, file.name)
 
         return output_file, translated_text
+
     except Exception as e:
         error_message = f"Error: {str(e)}"
-        st.error(f"An error occurred: {error_message}")
+        st.error(error_message)
         output_file = save_text_to_file(error_message, file.name, prefix="error")
         return output_file, error_message
 
@@ -209,10 +222,15 @@ def main():
     st.title("Advanced Document Translator")
     st.write("Upload a document (PDF, DOCX, or TXT) and select source and target languages.")
 
+    # Display current user and timestamp
+    st.sidebar.write(f"Current User: {os.getenv('USER', 'gauravchand')}")
+    st.sidebar.write(f"UTC Time: {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
+
     try:
-        # Initialize models
+        # Initialize models with error handling
         with st.spinner("Loading models..."):
             models = load_models()
+            st.success("Models loaded successfully!")
 
         # File uploader
         uploaded_file = st.file_uploader("Upload Document", type=["pdf", "docx", "txt"])
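
Editor's note: the commit introduces two UTC timestamp formats, one for the sidebar display and one for filenames. Side by side (sketch, not part of the commit):

# Editor's sketch: the two strftime formats used in this commit.
from datetime import datetime, timezone

now = datetime.now(timezone.utc)
print(now.strftime("%Y-%m-%d %H:%M:%S"))  # sidebar display, e.g. 2025-01-01 12:00:00
print(now.strftime("%Y%m%d_%H%M%S"))      # filename-safe, e.g. 20250101_120000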
@@ -224,9 +242,6 @@ def main():
         with col2:
             target_lang = st.selectbox("Target Language", ["en", "hi", "mr"], index=1)
 
-        # Add debug mode toggle
-        debug_mode = st.sidebar.checkbox("Enable Debug Mode")
-
         if uploaded_file is not None and st.button("Translate"):
             with st.spinner("Processing and Translating..."):
                 output_file, result_text = process_document(uploaded_file, source_lang, target_lang, models)
@@ -242,13 +257,10 @@ def main():
                     file_name=os.path.basename(output_file),
                     mime="text/plain"
                 )
-
-            if debug_mode:
-                st.sidebar.write("Translation completed")
-                st.sidebar.write("Output file:", output_file)
 
     except Exception as e:
-        st.error(f"An error occurred while initializing the application: {str(e)}")
+        st.error(f"Application error: {str(e)}")
+        st.warning("Please try refreshing the page or contact support.")
 
 if __name__ == "__main__":
     main()
 