# src/inference_t5.py
"""
Allows to predict the summary for a given entry text
"""
import os
import re
import string

# Set the cache directory before importing transformers so it takes effect.
os.environ["TRANSFORMERS_CACHE"] = "./.cache"

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

def clean_text(text: str) -> str:
    """Lowercase the text, strip punctuation, and replace newlines with spaces."""
    text = text.lower()
    text = text.translate(str.maketrans("", "", string.punctuation))
    text = re.sub(r"\n", " ", text)
    return text
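
# For example, clean_text("Hello, World!\nBye.") returns "hello world bye".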

def inference_t5(text: str) -> str:
    """
    Predict the summary for an input text.

    Parameters
    ----------
    text: str
        the text to summarize

    Returns
    -------
    str
        the summary of the input text
    """
    # Set up the model inputs.
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Read the Hugging Face token from the environment (e.g. `export HF_TOKEN=...`)
    # instead of hardcoding the secret in the source.
    hf_token = os.environ.get("HF_TOKEN")
    tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=hf_token)
    # Load the model from the Hugging Face Hub and move it to the selected device.
    model = (
        AutoModelForSeq2SeqLM
        .from_pretrained("Linggg/t5_summary", use_auth_token=hf_token)
        .to(device)
    )
    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    ).to(device)  # keep the input tensors on the same device as the model
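    # Generate the summary with beam search (8 beams, at most 128 tokens).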
    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"],
        attention_mask=text_encoding["attention_mask"],
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )
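    # Decode the generated token ids back into text.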
    preds = [
        tokenizer.decode(
            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)

if __name__ == "__main__":
    text = input("Enter the text to summarize: ")
    print("summary:", inference_t5(text))