"""Streamlit helpers for French text correction with an mBART model.

Loads the ``aligator/mBART_french_correction`` model and tokenizer through
Streamlit's resource cache (so they are not reloaded on every script rerun)
and exposes :func:`correct`, which rewrites a sentence using beam search.
"""
import math

import streamlit as st
from datasets import load_from_disk
from datasets.filesystems import S3FileSystem
from transformers import MBartForConditionalGeneration, MBartTokenizerFast

# Anonymous (unsigned) S3 filesystem handle. Neither ``s3`` nor
# ``load_from_disk`` is referenced in this module as shown; they are kept in
# case other code imports them from here.
s3 = S3FileSystem(anon=True)

# Hugging Face Hub checkpoint used for both the model and its tokenizer.
MODEL_NAME = "aligator/mBART_french_correction"
# mBART language code: used as the tokenizer's source language and as the
# forced first generated token so that the output stays in French.
LANG_CODE = "fr_XX"

# ``st.cache`` is deprecated in favour of ``st.cache_resource`` for shared,
# unhashable objects such as models/tokenizers (the role that
# ``allow_output_mutation=True`` played). Use the new API when available and
# fall back for older Streamlit releases that do not provide it.
if hasattr(st, "cache_resource"):
    _cache_resource = st.cache_resource
else:
    _cache_resource = st.cache(allow_output_mutation=True)


@_cache_resource
def load_model():
    """Load the mBART correction model (cached by Streamlit)."""
    print("Load correction model")
    return MBartForConditionalGeneration.from_pretrained(MODEL_NAME)


@_cache_resource
def load_tokenizer():
    """Load the fast mBART tokenizer for the correction model (cached by Streamlit)."""
    print("Load tokenizer for correction model")
    return MBartTokenizerFast.from_pretrained(MODEL_NAME)


model = load_model()
tokenizer = load_tokenizer()


def correct(
    sentence: str,
    *,
    num_beams: int = 5,
    repetition_penalty: float = 1.1,
    min_length_ratio: float = 0.8,
    max_length_ratio: float = 1.2,
) -> str:
    """Return a corrected version of the French text ``sentence``.

    Args:
        sentence: Text to correct.
        num_beams: Beam width passed to ``model.generate``.
        repetition_penalty: Passed to ``model.generate``; 1.0 means no penalty.
        min_length_ratio: Minimum output length, as a fraction of the tokenized
            input length (special tokens included), rounded up.
        max_length_ratio: Maximum output length, as a fraction of the tokenized
            input length (special tokens included), rounded up.

    Returns:
        The first generated sequence, decoded with special tokens removed.
    """
    # The tokenizer is a shared cached object; set the source language on
    # every call so the correct language-code token is appended.
    tokenizer.src_lang = LANG_CODE
    encoded = tokenizer(sentence, return_tensors="pt")
    # Length (in tokens) of the single encoded sequence; the output length is
    # bounded relative to it because a correction should be about as long as
    # its input.
    n_input_tokens = len(encoded.input_ids[0])

    generated_tokens = model.generate(
        **encoded,
        # Force French as the first decoded token (mBART target-language id).
        forced_bos_token_id=tokenizer.convert_tokens_to_ids(LANG_CODE),
        max_length=math.ceil(n_input_tokens * max_length_ratio),
        min_length=math.ceil(n_input_tokens * min_length_ratio),
        num_beams=num_beams,
        repetition_penalty=repetition_penalty,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]