# (Pasted Hugging Face Spaces status banner — not code; kept as a comment.)
# Spaces: Runtime error
"""
Allows to predict the summary for a given entry text.
"""
import re
import string
import os

# Must be set BEFORE importing transformers so the model/tokenizer cache
# is written to a local directory instead of the default user cache.
os.environ['TRANSFORMERS_CACHE'] = './.cache'

import contractions
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
def clean_text(texts: str) -> str:
    """Normalize *texts*: lowercase it, strip all ASCII punctuation,
    and flatten newlines into single spaces."""
    lowered = texts.lower()
    # One C-level pass removes every character in string.punctuation.
    without_punct = lowered.translate(str.maketrans("", "", string.punctuation))
    return without_punct.replace("\n", " ")
def inference_t5(text: str) -> str:
    """
    Predict the summary for an input text.

    Parameters
    ----------
    text : str
        The text to summarize.

    Returns
    -------
    str
        The summary for the input text.
    """
    # Normalize the input the same way the model was trained on.
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # SECURITY FIX: the Hugging Face auth token used to be hard-coded here
    # (a committed secret). Read it from the environment instead; set
    # HF_AUTH_TOKEN before calling (None is fine for public models).
    hf_token = os.environ.get("HF_AUTH_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained(
        "Linggg/t5_summary", use_auth_token=hf_token
    )
    model = (
        AutoModelForSeq2SeqLM
        .from_pretrained("Linggg/t5_summary", use_auth_token=hf_token)
        .to(device)
    )
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    # BUG FIX: the input tensors must be moved to the same device as the
    # model — previously they stayed on CPU, so generation crashed whenever
    # CUDA was available.
    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"].to(device),
        attention_mask=text_encoding["attention_mask"].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )
    preds = [
        tokenizer.decode(
            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)
# if __name__ == "__main__":
#     text = input('Entrez votre phrase à résumer : ')
#     print('summary:', inference_t5(text))