""" Allows to predict the summary for a given entry text """ import re import string import os os.environ['TRANSFORMERS_CACHE'] = './.cache' import contractions import torch from transformers import AutoModelForSeq2SeqLM, AutoTokenizer def clean_text(texts: str) -> str: texts = texts.lower() texts = texts.translate(str.maketrans("", "", string.punctuation)) texts = re.sub(r"\n", " ", texts) return texts def inference_t5(text: str) -> str: """ Predict the summary for an input text -------- Parameter text: str the text to sumarize Return str The summary for the input text """ # On défini les paramètres d'entrée pour le modèle text = clean_text(text) device = torch.device("cuda" if torch.cuda.is_available() else "cpu") hf_token = "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" tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=hf_token ) # load local model model = (AutoModelForSeq2SeqLM .from_pretrained("Linggg/t5_summary", use_auth_token = hf_token ) .to(device)) text_encoding = tokenizer( text, max_length=1024, padding="max_length", truncation=True, return_attention_mask=True, add_special_tokens=True, return_tensors="pt", ) generated_ids = model.generate( input_ids=text_encoding["input_ids"], attention_mask=text_encoding["attention_mask"], max_length=128, num_beams=8, length_penalty=0.8, early_stopping=True, ) preds = [ tokenizer.decode( gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True ) for gen_id in generated_ids ] return "".join(preds) # if __name__ == "__main__": # text = input('Entrez votre phrase à résumer : ') # print('summary:', inferenceAPI_T5(text))