gauravchand11 commited on
Commit
23bd434
·
verified ·
1 Parent(s): 8662527

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +8 -3
app.py CHANGED
@@ -9,6 +9,7 @@ import tempfile
9
  from typing import Union, Tuple
10
  import os
11
  from datetime import datetime, timezone
 
12
 
13
  # Display current information
14
  st.sidebar.text(f"Current Time (UTC): {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
@@ -59,6 +60,7 @@ def load_models():
59
  nllb_tokenizer = AutoTokenizer.from_pretrained(
60
  "facebook/nllb-200-distilled-600M",
61
  token=HF_TOKEN,
 
62
  trust_remote_code=True
63
  )
64
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
@@ -189,14 +191,17 @@ def translate_text(text: str, source_lang: str, target_lang: str, nllb_tuple: Tu
189
  translated_batches = []
190
 
191
  for batch in batches:
 
 
 
 
192
  inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
193
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
194
 
195
- forced_bos_token_id = tokenizer.lang_code_to_id[target_lang]
196
-
197
  outputs = model.generate(
198
  **inputs,
199
- forced_bos_token_id=forced_bos_token_id,
200
  max_length=512,
201
  temperature=0.7,
202
  num_beams=5,
 
9
  from typing import Union, Tuple
10
  import os
11
  from datetime import datetime, timezone
12
+ import sys
13
 
14
  # Display current information
15
  st.sidebar.text(f"Current Time (UTC): {datetime.now(timezone.utc).strftime('%Y-%m-%d %H:%M:%S')}")
 
60
  nllb_tokenizer = AutoTokenizer.from_pretrained(
61
  "facebook/nllb-200-distilled-600M",
62
  token=HF_TOKEN,
63
+ src_lang="eng_Latn", # Default source language
64
  trust_remote_code=True
65
  )
66
  nllb_model = AutoModelForSeq2SeqLM.from_pretrained(
 
191
  translated_batches = []
192
 
193
  for batch in batches:
194
+ # Set the source language for the tokenizer
195
+ tokenizer.src_lang = source_lang
196
+
197
+ # Prepare the input text
198
  inputs = tokenizer(batch, return_tensors="pt", max_length=512, truncation=True)
199
  inputs = {k: v.to(model.device) for k, v in inputs.items()}
200
 
201
+ # Generate translation with forced target language
 
202
  outputs = model.generate(
203
  **inputs,
204
+ forced_bos_token_id=tokenizer.get_lang_id(target_lang),
205
  max_length=512,
206
  temperature=0.7,
207
  num_beams=5,