french_gec / corrector.py
alice-hml's picture
Upload corrector.py
9d78076
raw
history blame
No virus
1.4 kB
import math
from transformers import MBartTokenizerFast, MBartForConditionalGeneration
import streamlit as st
from datasets import load_from_disk
from datasets.filesystems import S3FileSystem
# S3 filesystem handle created with anon=True (presumably unauthenticated,
# read-only access — TODO confirm).
# NOTE(review): nothing in the visible part of this file references `s3`;
# verify whether another module imports it before removing.
s3 = S3FileSystem(anon=True)
@st.cache(allow_output_mutation=True)
def load_model():
    """Load the pretrained mBART French-correction model.

    The function is wrapped in ``st.cache`` with ``allow_output_mutation=True``
    so the loaded object can be reused by Streamlit instead of being rebuilt.

    Returns:
        The ``MBartForConditionalGeneration`` instance created from the
        "alice-hml/mBART_french_correction" checkpoint.
    """
    print("Load correction model")
    checkpoint = "alice-hml/mBART_french_correction"
    correction_model = MBartForConditionalGeneration.from_pretrained(checkpoint)
    return correction_model
@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load the tokenizer that accompanies the French-correction model.

    The function is wrapped in ``st.cache`` with ``allow_output_mutation=True``
    so the loaded object can be reused by Streamlit instead of being rebuilt.

    Returns:
        The ``MBartTokenizerFast`` instance created from the
        "alice-hml/mBART_french_correction" checkpoint.
    """
    print("Load tokenizer for correction model")
    checkpoint = "alice-hml/mBART_french_correction"
    correction_tokenizer = MBartTokenizerFast.from_pretrained(checkpoint)
    return correction_tokenizer
# Instantiate the correction model and its tokenizer at import time.
# correct() reads both names as module-level globals.
model = load_model()
tokenizer = load_tokenizer()
def correct(sentence: str) -> str:
    """Return a corrected version of a French sentence.

    The sentence is tokenized with "fr_XX" as the source language and passed
    to the module-level ``model`` for beam-search generation (5 beams), with
    the first generated token forced to the "fr_XX" language code.

    Args:
        sentence: Text to correct.

    Returns:
        The first decoded sequence with special tokens removed. Empty or
        whitespace-only input is returned unchanged without calling the model.
    """
    # Nothing to correct: skip the costly beam search instead of letting the
    # model generate from an empty prompt.
    if not sentence.strip():
        return sentence

    tokenizer.src_lang = "fr_XX"
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    # Token count of the single encoded input; used for both length bounds.
    input_length = len(encoded_orig.input_ids[0])
    generated_tokens = model.generate(
        **encoded_orig,
        forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"],
        # Output length is bounded to 80 %-120 % of the input token count,
        # each bound rounded up.
        max_length=math.ceil(input_length * 1.20),
        min_length=math.ceil(input_length * 0.8),
        num_beams=5,
        repetition_penalty=1.1,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]