"""Predict the summary of a given input text with a local T5 model."""

from functools import lru_cache
import re
import string

import contractions
import nltk
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

# NOTE(review): neither the 'stopwords' nor the 'punkt' resource is used
# anywhere in this module. The downloads are kept so import-time behavior is
# unchanged, in case other code relies on them; consider removing them.
nltk.download('stopwords')
nltk.download('punkt')

# Directory holding the fine-tuned summarization model and its tokenizer.
MODEL_DIR = "./summarization_t5"


def clean_data(texts):
    """Normalize raw text before it is fed to the summarization model.

    The steps are applied in this order:

    1. lowercase the text;
    2. expand English contractions with ``contractions.fix``;
    3. delete every character in ``string.punctuation``;
    4. replace each newline with a space.

    Parameters
    ----------
    texts : str
        The raw input text.

    Returns
    -------
    str
        The normalized text.
    """
    texts = texts.lower()
    # Expand contractions before stripping punctuation; otherwise the
    # apostrophes would be removed first and "don't" would become "dont".
    texts = contractions.fix(texts)
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    texts = re.sub(r'\n', ' ', texts)
    return texts


@lru_cache(maxsize=None)
def _load_model(model_dir):
    """Load the tokenizer and model stored in *model_dir*, once per directory.

    Loading a seq2seq model from disk is expensive, so the result is cached
    and reused by every later call. The model therefore stays in memory for
    the lifetime of the process.

    Parameters
    ----------
    model_dir : str
        Path to a directory readable by ``from_pretrained``.

    Returns
    -------
    tuple
        ``(tokenizer, model, device)``. The model has already been moved to
        ``device``.
    """
    # Use the GPU when one is available, otherwise fall back to the CPU.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    tokenizer = AutoTokenizer.from_pretrained(model_dir)
    model = AutoModelForSeq2SeqLM.from_pretrained(model_dir).to(device)
    return tokenizer, model, device


def inferenceAPI_t5(text: str) -> str:
    """Predict the summary for an input text.

    Parameters
    ----------
    text : str
        The text to summarize. It is normalized with ``clean_data`` before
        tokenization, and truncated to at most 1024 tokens.

    Returns
    -------
    str
        The summary generated for the input text (at most 128 tokens).
    """
    text = clean_data(text)
    tokenizer, model, device = _load_model(MODEL_DIR)

    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding='max_length',
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors='pt',
    )

    generated_ids = model.generate(
        # The tokenizer produces CPU tensors. They must live on the same
        # device as the model, or generate() fails when running on a GPU.
        input_ids=text_encoding['input_ids'].to(device),
        attention_mask=text_encoding['attention_mask'].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(
            gen_id,
            skip_special_tokens=True,
            clean_up_tokenization_spaces=True,
        )
        for gen_id in generated_ids
    ]
    # generate() was called on a batch of one input; concatenate the decoded
    # sequence(s) into a single string.
    return "".join(preds)


if __name__ == "__main__":
    text = input('Entrez votre phrase à résumer : ')
    print('summary:', inferenceAPI_t5(text))