Runtime error
import math

from transformers import MBartTokenizerFast, MBartForConditionalGeneration
import streamlit as st
from datasets import load_from_disk
from datasets.filesystems import S3FileSystem

# Anonymous S3 filesystem handle; defined here but not used later in this listing.
s3 = S3FileSystem(anon=True)

def load_model():
    print("Load correction model")
    return MBartForConditionalGeneration.from_pretrained("alice-hml/mBART_french_correction")

def load_tokenizer():
    print("Load tokenizer for correction model")
    return MBartTokenizerFast.from_pretrained("alice-hml/mBART_french_correction")

# Model and tokenizer are loaded once, at app startup.
model = load_model()
tokenizer = load_tokenizer()

def correct(sentence: str):
    # Encode the input as French and generate a corrected sentence with beam search,
    # constraining the output length to roughly 80-120% of the input length.
    tokenizer.src_lang = "fr_XX"
    encoded_orig = tokenizer(sentence, return_tensors="pt")
    generated_tokens = model.generate(
        **encoded_orig,
        forced_bos_token_id=tokenizer.lang_code_to_id["fr_XX"],
        max_length=math.ceil(len(encoded_orig.input_ids[0]) * 1.20),
        min_length=math.ceil(len(encoded_orig.input_ids[0]) * 0.8),
        num_beams=5,
        repetition_penalty=1.1,
        # max_time=5,
    )
    return tokenizer.batch_decode(generated_tokens, skip_special_tokens=True)[0]
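The listing imports streamlit but stops before any UI code. A minimal sketch of how correct() might be wired into a Streamlit front end is shown below; the widget labels and layout are assumptions for illustration, not part of the original app.

# Hypothetical Streamlit UI for the correction model (sketch only).
st.title("French grammatical error correction")
user_sentence = st.text_input("Enter a French sentence to correct")
if user_sentence:
    with st.spinner("Correcting..."):
        corrected = correct(user_sentence)
    st.write(corrected)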