AreesaAshfaq committed on
Commit
1ff3b59
·
verified ·
1 Parent(s): a27cb28

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +16 -4
app.py CHANGED
@@ -76,22 +76,34 @@ LANGUAGE_MODELS = {
76
  'Zulu': 'zu',
77
  }
78
 
 
79
  @st.cache_resource
80
  def load_model():
81
  tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
82
  model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
83
  return tokenizer, model
84
 
85
- def translate(text, target_language):
86
  tokenizer, model = load_model()
87
 
88
  # Set the target language code for translation
89
  target_lang_code = LANGUAGE_MODELS.get(target_language)
90
 
91
  tokenizer.src_lang = "en"
92
- encoded_input = tokenizer(text, return_tensors="pt")
93
- generated_tokens = model.generate(**encoded_input, forced_bos_token_id=tokenizer.get_lang_id(target_lang_code))
94
- translation = tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
 
 
 
 
 
 
 
 
 
 
 
95
  return translation
96
 
97
  st.title('Welcome to PolyTranslate')
 
76
  'Zulu': 'zu',
77
  }
78
 
79
+
80
  @st.cache_resource
81
  def load_model():
82
  tokenizer = M2M100Tokenizer.from_pretrained("facebook/m2m100_418M")
83
  model = M2M100ForConditionalGeneration.from_pretrained("facebook/m2m100_418M")
84
  return tokenizer, model
85
 
86
def _split_text_for_translation(text, tokenizer, max_tokens):
    """Split *text* into pieces of at most *max_tokens* content tokens each.

    Whole sentences are packed greedily so the model sees complete sentences
    whenever possible. A single sentence that exceeds the budget on its own is
    cut at token boundaries as a last resort.

    Args:
        text: The source text to split.
        tokenizer: Tokenizer used to measure (and, if needed, cut) sentences.
        max_tokens: Maximum number of non-special tokens per piece.

    Returns:
        A list of non-empty text pieces, in their original order.
    """
    import re  # local import: the file's import block is outside this block

    pieces = []
    pending, pending_len = [], 0

    def flush():
        nonlocal pending, pending_len
        if pending:
            pieces.append(" ".join(pending))
        pending, pending_len = [], 0

    for sentence in re.split(r"(?<=[.!?])\s+", text.strip()):
        ids = tokenizer.encode(sentence, add_special_tokens=False)
        if not ids:
            continue

        if len(ids) > max_tokens:
            # The sentence cannot fit in any chunk: emit what we have, then
            # cut this sentence by token count and turn each slice back into
            # text so it is re-tokenized with its own special tokens later.
            flush()
            for start in range(0, len(ids), max_tokens):
                fragment = tokenizer.decode(
                    ids[start:start + max_tokens], skip_special_tokens=True
                )
                if fragment.strip():
                    pieces.append(fragment)
            continue

        if pending_len + len(ids) > max_tokens:
            flush()
        pending.append(sentence)
        pending_len += len(ids)

    flush()
    return pieces


def translate(text, target_language, max_chunk_size=500):
    """Translate English *text* into *target_language*.

    Long input is split into chunks of at most *max_chunk_size* tokens. Each
    chunk is tokenized on its own, so every model input carries the tokenizer's
    special tokens (source-language prefix and end-of-sequence) instead of
    being an arbitrary slice of one long token sequence.

    Args:
        text: English source text.
        target_language: A key of ``LANGUAGE_MODELS`` naming the output language.
        max_chunk_size: Maximum number of content tokens per chunk.

    Returns:
        The translated chunks joined by single spaces, or ``""`` when *text*
        is empty or whitespace only.

    Raises:
        ValueError: If *max_chunk_size* is not positive.
        KeyError: If *target_language* is not a key of ``LANGUAGE_MODELS``.
    """
    if max_chunk_size < 1:
        raise ValueError("max_chunk_size must be a positive integer")
    if not text or not text.strip():
        # Nothing to translate; avoid loading the model and generating from
        # an input that consists of special tokens only.
        return ""

    # Resolve the language code up front so an unknown language fails with a
    # clear message rather than deep inside the tokenizer.
    target_lang_code = LANGUAGE_MODELS.get(target_language)
    if target_lang_code is None:
        raise KeyError(f"Unsupported target language: {target_language!r}")

    tokenizer, model = load_model()
    tokenizer.src_lang = "en"
    # Loop-invariant: the forced first decoder token selects the output language.
    forced_bos_token_id = tokenizer.get_lang_id(target_lang_code)

    translations = []
    for chunk in _split_text_for_translation(text, tokenizer, max_chunk_size):
        encoded = tokenizer(chunk, return_tensors="pt")
        input_length = encoded["input_ids"].shape[1]
        # NOTE(review): without an explicit limit generate() falls back to the
        # checkpoint's default maximum length, which is presumably shorter
        # than the translation of a full chunk - confirm against the model's
        # generation config. The floor keeps generous headroom for short inputs.
        generated_tokens = model.generate(
            **encoded,
            forced_bos_token_id=forced_bos_token_id,
            max_new_tokens=max(200, 2 * input_length),
        )
        translations.append(
            tokenizer.decode(generated_tokens[0], skip_special_tokens=True)
        )

    # Combine all translated chunks
    return " ".join(translations)
108
 
109
  st.title('Welcome to PolyTranslate')