File size: 1,402 Bytes
2329469
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
import math
from transformers import MBartTokenizerFast, MBartForConditionalGeneration
import streamlit as st
from datasets import load_from_disk
from datasets.filesystems import S3FileSystem

# Anonymous (unauthenticated) S3 filesystem handle.
# NOTE(review): `s3` and the `load_from_disk` import appear unused in this
# file — presumably left over from a dataset-loading step; confirm with the
# rest of the project before removing.
s3 = S3FileSystem(anon=True)  

@st.cache(allow_output_mutation=True)
def load_model():
    """Load and cache the mBART French-correction model.

    Cached by Streamlit so reruns of the app reuse the already-loaded
    weights instead of re-downloading them.
    (NOTE(review): `st.cache` is deprecated in newer Streamlit releases in
    favour of `st.cache_resource` — confirm the pinned version before
    migrating.)
    """
    print("Load correction model")
    checkpoint = "aligator/mBART_french_correction"
    return MBartForConditionalGeneration.from_pretrained(checkpoint)


@st.cache(allow_output_mutation=True)
def load_tokenizer():
    """Load and cache the fast mBART tokenizer matching the correction model.

    Cached by Streamlit so reruns of the app reuse the already-built
    tokenizer instead of reloading it.
    """
    print("Load tokenizer for correction model")
    checkpoint = "aligator/mBART_french_correction"
    return MBartTokenizerFast.from_pretrained(checkpoint)


# Instantiate the (Streamlit-cached) model and tokenizer once at import time;
# subsequent app reruns hit the cache instead of reloading from disk/network.
model = load_model()
tokenizer = load_tokenizer()

def correct(sentence: str) -> str:
    """Return a corrected version of the French *sentence*.

    The input is tokenized as French (``fr_XX``) and the model is forced to
    generate French output. Generation length is bounded to 80%–120% of the
    input token count so the model can neither drop nor hallucinate large
    spans of text.

    Args:
        sentence: Raw French text to correct.

    Returns:
        The model's corrected sentence, with special tokens stripped.
    """
    tokenizer.src_lang = "fr_XX"
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    # Hoist the input length: it drives both generation-length bounds below.
    n_tokens = len(encoded_orig.input_ids[0])
    generated_tokens = model.generate(
        **encoded_orig,
        # Force the decoder to start in French.
        forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"],
        max_length=math.ceil(n_tokens * 1.20),
        min_length=math.ceil(n_tokens * 0.8),
        num_beams=5,
        repetition_penalty=1.1,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]