Update app.py
Browse files
app.py
CHANGED
@@ -289,7 +289,7 @@ model.eval()
|
|
289 |
bot_name = "WeASK"
|
290 |
|
291 |
###removed
|
292 |
-
from transformers import MBartForConditionalGeneration,
|
293 |
|
294 |
#def download_model():
|
295 |
|
@@ -299,11 +299,15 @@ from transformers import MBartForConditionalGeneration, MBart50TokenizerFast
|
|
299 |
|
300 |
|
301 |
################################
|
302 |
-
def
|
303 |
model_name = "facebook/mbart-large-50-many-to-many-mmt"
|
304 |
model = MBartForConditionalGeneration.from_pretrained(model_name)
|
305 |
-
tokenizer =
|
|
|
|
|
|
|
306 |
|
|
|
307 |
model_inputs = tokenizer(input_text, return_tensors="pt")
|
308 |
generated_tokens = model.generate(**model_inputs,forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
|
309 |
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|
|
|
289 |
bot_name = "WeASK"
|
290 |
|
291 |
###removed
|
292 |
+
from transformers import MBartForConditionalGeneration, MBart50Tokenizer
|
293 |
|
294 |
#def download_model():
|
295 |
|
|
|
299 |
|
300 |
|
301 |
################################
|
302 |
+
def download_model():
|
303 |
model_name = "facebook/mbart-large-50-many-to-many-mmt"
|
304 |
model = MBartForConditionalGeneration.from_pretrained(model_name)
|
305 |
+
tokenizer = MBart50Tokenizer.from_pretrained(model_name)
|
306 |
+
return model, tokenizer
|
307 |
+
|
308 |
+
model, tokenizer = download_model()
|
309 |
|
310 |
+
def get_response(input_text):
|
311 |
model_inputs = tokenizer(input_text, return_tensors="pt")
|
312 |
generated_tokens = model.generate(**model_inputs,forced_bos_token_id=tokenizer.lang_code_to_id["en_XX"])
|
313 |
translation = tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)
|