gauravchand11 committed
Commit 90c759f · verified · 1 Parent(s): 5a89d4a

Update app.py

Files changed (1)
  1. app.py +23 -29
app.py CHANGED
@@ -23,7 +23,7 @@ st.set_page_config(
     layout="wide"
 )
 
-# Display current information in sidebar with proper formatting
+# Display current information in sidebar
 current_time = datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')
 st.sidebar.markdown("""
 ### System Information
@@ -76,8 +76,8 @@ def load_models():
     nllb_tokenizer = AutoTokenizer.from_pretrained(
         "facebook/nllb-200-distilled-600M",
         token=HF_TOKEN,
-        trust_remote_code=True,
-        use_fast=False  # Use slow tokenizer to avoid warnings
+        src_lang="eng_Latn",
+        trust_remote_code=True
     )
     nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
         "facebook/nllb-200-distilled-600M",
@@ -89,14 +89,12 @@ def load_models():
 
     # Load MT5 model for grammar correction
     mt5_tokenizer = AutoTokenizer.from_pretrained(
-        "google/mt5-small",
+        "google/mt5-base",  # Changed to base model for better performance
         token=HF_TOKEN,
-        trust_remote_code=True,
-        legacy=False,  # Use new behavior
-        use_fast=False  # Use slow tokenizer to avoid warnings
+        trust_remote_code=True
     )
     mt5_model = MT5ForConditionalGeneration.from_pretrained(
-        "google/mt5-small",
+        "google/mt5-base",  # Changed to base model for better performance
         token=HF_TOKEN,
         torch_dtype=torch.float16,
         device_map="auto" if torch.cuda.is_available() else None,
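Note: torch_dtype=torch.float16 is still requested when no GPU is present (device_map falls back to None), and half precision on CPU is slow or unsupported for many ops. A small sketch of picking the dtype from the available device, offered as an assumption rather than part of the commit:

    # Hypothetical guard: load in fp16 only when a CUDA device will host the model.
    dtype = torch.float16 if torch.cuda.is_available() else torch.float32
    mt5_model = MT5ForConditionalGeneration.from_pretrained(
        "google/mt5-base",
        token=HF_TOKEN,
        torch_dtype=dtype,
        device_map="auto" if torch.cuda.is_available() else None,
    )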
@@ -177,9 +175,7 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
     interpreted_batches = []
 
     for batch in batches:
-        prompt = f"""Analyze the following text for context and cultural nuances,
-        maintaining the core meaning while identifying any idiomatic expressions or
-        cultural references: {batch}"""
+        prompt = f"""Analyze and maintain the core meaning of this text: {batch}"""
 
         inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
@@ -194,6 +190,8 @@ def interpret_context(text: str, gemma_tuple: Tuple) -> str:
         )
 
         interpreted_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Remove the prompt from the output
+        interpreted_text = interpreted_text.replace(prompt, "").strip()
         interpreted_batches.append(interpreted_text)
 
     return " ".join(interpreted_batches)
@@ -207,17 +205,12 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
     translated_batches = []
 
     for batch in batches:
-        # Add source language token to input
-        batch_with_lang = f"{source_lang} {batch}"
-        inputs = tokenizer(batch_with_lang, return_tensors="pt", max_length=512, truncation=True)
+        inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
-        # Add target language token
-        target_lang_token = tokenizer(target_lang, add_special_tokens=False)["input_ids"][0]
-
         outputs = model.generate(
             **inputs,
-            forced_bos_token_id=target_lang_token,
+            forced_bos_token_id=tokenizer.lang_code_to_id[target_lang],
             max_length=512,
             do_sample=True,
             temperature=0.7,
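Note: forced_bos_token_id is what selects the NLLB target language, since decoding must start with the target-language tag token. tokenizer.lang_code_to_id exists on older transformers releases but has been removed in newer ones, so a version-agnostic sketch (an assumption, not taken from the commit) would be:

    # Look up the target-language tag token id directly from the vocabulary,
    # e.g. target_lang = "hin_Deva" for Hindi in NLLB's FLORES-200 codes.
    target_lang_id = tokenizer.convert_tokens_to_ids(target_lang)
    outputs = model.generate(**inputs, forced_bos_token_id=target_lang_id, max_length=512)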
@@ -236,35 +229,36 @@ def correct_grammar(text: str, target_lang: str, mt5_tuple: Tuple) -> str:
     tokenizer, model = mt5_tuple
     lang_code = MT5_LANG_CODES[target_lang]
 
+    # Language-specific prompts for grammar correction
     prompts = {
-        'en': "grammar: ",
-        'hi': "व्याकरण सुधार: ",
-        'mr': "व्याकरण सुधारणा: "
+        'en': "Fix grammar: ",
+        'hi': "व्याकरण: ",
+        'mr': "व्याकरण: "
     }
 
     batches = batch_process_text(text)
     corrected_batches = []
 
     for batch in batches:
-        prompt = prompts[lang_code] + batch
-        inputs = tokenizer(prompt, return_tensors="pt", max_length=512, truncation=True)
+        # Prepare input with target language prefix
+        input_text = f"{prompts[lang_code]}{batch}"
+        inputs = tokenizer(input_text, return_tensors="pt", max_length=512, truncation=True)
         inputs = {k: v.to(model.device) for k, v in inputs.items()}
 
         outputs = model.generate(
             **inputs,
             max_length=512,
             num_beams=5,
-            do_sample=True,
-            temperature=0.7,
-            top_p=0.9,
-            num_return_sequences=1
+            length_penalty=1.0,
+            early_stopping=True,
+            do_sample=False  # Disable sampling for more stable output
         )
 
         corrected_text = tokenizer.decode(outputs[0], skip_special_tokens=True)
+        # Clean up the output
         for prefix in prompts.values():
            corrected_text = corrected_text.replace(prefix, "")
-        corrected_text = corrected_text.strip()
-
+        corrected_text = corrected_text.replace("<extra_id_0>", "").replace("<extra_id_1>", "").strip()
         corrected_batches.append(corrected_text)
 
     return " ".join(corrected_batches)
 