Spaces:
Runtime error
Runtime error
import math | |
from transformers import MBartTokenizerFast, MBartForConditionalGeneration | |
import streamlit as st | |
from datasets import load_from_disk | |
from datasets.filesystems import S3FileSystem | |
# Anonymous (unauthenticated) S3 filesystem handle. It is not referenced by
# the code visible in this section — presumably used with load_from_disk
# elsewhere; verify.
# NOTE(review): datasets.filesystems.S3FileSystem is deprecated in recent
# `datasets` releases (s3fs.S3FileSystem is the suggested replacement) —
# confirm the pinned `datasets` version still provides it.
s3 = S3FileSystem(anon=True)
def load_model():
    """Instantiate the mBART French-correction model.

    Logs a short progress message to stdout, then builds the model with
    ``MBartForConditionalGeneration.from_pretrained``.

    Returns:
        The ``aligator/mBART_french_correction`` checkpoint as an
        ``MBartForConditionalGeneration`` instance.
    """
    # NOTE(review): Streamlit re-executes the script on each interaction; if
    # this is called at module level it reloads the model every rerun —
    # consider st.cache_resource if the pinned Streamlit version supports it.
    checkpoint = "aligator/mBART_french_correction"
    print("Load correction model")
    return MBartForConditionalGeneration.from_pretrained(checkpoint)
def load_tokenizer():
    """Instantiate the fast tokenizer paired with the correction model.

    Logs a short progress message to stdout, then builds the tokenizer with
    ``MBartTokenizerFast.from_pretrained``.

    Returns:
        The ``aligator/mBART_french_correction`` tokenizer as an
        ``MBartTokenizerFast`` instance.
    """
    # Same checkpoint id as the model so vocabularies match.
    checkpoint = "aligator/mBART_french_correction"
    print("Load tokenizer for correction model")
    return MBartTokenizerFast.from_pretrained(checkpoint)
# Module-level singletons used by correct(). The model is loaded first, then
# the tokenizer (the same order as the calls below).
model, tokenizer = load_model(), load_tokenizer()
def correct(sentence: str) -> str:
    """Return the model's correction of a French sentence.

    The sentence is tokenized as French (``fr_XX``) and decoded with beam
    search (5 beams, repetition penalty 1.1). The first generated token is
    forced to the French language code. The output length is bounded
    relative to the encoded input: at least 80% and at most 120% of the
    input token count, rounded up.

    Uses the module-level ``model`` and ``tokenizer`` objects.

    Args:
        sentence: Text to correct.

    Returns:
        The first decoded sequence, with special tokens stripped.
    """
    tokenizer.src_lang = "fr_XX"
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    # Token count of the single encoded input (special tokens included);
    # computed once and reused for both length bounds.
    input_len = len(encoded_orig.input_ids[0])
    # convert_tokens_to_ids is the generic tokenizer API and yields the same
    # id as the tokenizer-specific ``lang_code_to_id`` mapping, which is
    # deprecated on some tokenizers in newer transformers releases.
    fr_token_id = tokenizer.convert_tokens_to_ids("fr_XX")
    generated_tokens = model.generate(
        **encoded_orig,
        forced_bos_token_id=fr_token_id,
        max_length=math.ceil(input_len * 1.20),
        min_length=math.ceil(input_len * 0.8),
        num_beams=5,
        repetition_penalty=1.1,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]