""" Allows to predict the summary for a given entry text """ import torch import re import string from transformers import AutoModelForSeq2SeqLM, AutoTokenizer def clean_text(texts: str) -> str: texts = texts.lower() texts = texts.translate(str.maketrans("", "", string.punctuation)) texts = re.sub(r'\n', ' ', texts) return texts def inferenceAPI_T5(text: str) -> str: """ Predict the summary for an input text -------- Parameter text: str the text to sumarize Return str The summary for the input text """ # On défini les paramètres d'entrée pour le modèle text = clean_text(text) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") tokenizer = (AutoTokenizer.from_pretrained("Linggg/t5_summary",use_auth_token=True)) # load local model model = (AutoModelForSeq2SeqLM .from_pretrained("Linggg/t5_summary",use_auth_token=True) .to(device)) text_encoding = tokenizer( text, max_length=1024, padding='max_length', truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors='pt' ) generated_ids = model.generate( input_ids=text_encoding['input_ids'], attention_mask=text_encoding['attention_mask'], max_length=128, num_beams=8, length_penalty=0.8, early_stopping=True ) preds = [ tokenizer.decode(gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True) for gen_id in generated_ids ] return "".join(preds) # if __name__ == "__main__": # text = input('Entrez votre phrase à résumer : ') # print('summary:', inferenceAPI_T5(text))