File size: 1,788 Bytes
41508f8
 
 
 
 
3c03f61
 
 
41508f8
 
6158825
 
41508f8
 
 
3c03f61
41508f8
 
6158825
 
41508f8
 
 
 
 
 
 
 
 
 
6158825
 
 
 
3c03f61
6158825
3c03f61
6158825
41508f8
 
 
3c03f61
41508f8
 
 
3c03f61
41508f8
 
3c03f61
 
41508f8
 
 
3c03f61
41508f8
 
 
3c03f61
 
 
6158825
41508f8
 
 
6158825
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
"""
 Predict the summary of a given input text with a pretrained seq2seq model (Linggg/t5_summary).
"""
import functools
import re
import string

import contractions
import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def clean_text(texts: str) -> str:
    """
    Normalize raw input text before it is fed to the summarization model.
    --------
    Steps, in order: lowercase, expand English contractions with
    ``contractions.fix``, lowercase again, strip ASCII punctuation
    (``string.punctuation``) and turn newlines into spaces.

    Parameter
        texts: str
            the raw text to normalize (a single string, despite the name)
    Return
        str
            the normalized text
    """
    texts = texts.lower()
    texts = contractions.fix(texts)
    # The contraction expansion table may reintroduce capitals
    # (e.g. "i'm" -> "I am"), so lowercase again to keep the output
    # uniformly lowercase as intended by the first step.
    texts = texts.lower()
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    # A plain string replace is enough for a single literal character.
    return texts.replace("\n", " ")


_MODEL_NAME = "Linggg/t5_summary"


@functools.lru_cache(maxsize=1)
def _load_summarizer():
    """
    Load the tokenizer and model once and reuse them on later calls.
    --------
    Return
        tuple
            (tokenizer, model, device), with the model already moved to
            ``device`` (CUDA when available, CPU otherwise)
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # ``_MODEL_NAME`` is resolved by transformers' ``from_pretrained``
    # (Hugging Face Hub repo id or local cache/directory of that name).
    tokenizer = AutoTokenizer.from_pretrained(_MODEL_NAME)
    model = AutoModelForSeq2SeqLM.from_pretrained(_MODEL_NAME).to(device)
    return tokenizer, model, device


def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    The text is normalized with ``clean_text`` first. The tokenizer and
    model are loaded lazily on the first call and cached for later calls.

    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text
    """
    # Prepare the model inputs.
    text = clean_text(text)
    tokenizer, model, device = _load_summarizer()

    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    generated_ids = model.generate(
        # Inputs must live on the same device as the model; otherwise
        # generation fails with a device mismatch on a CUDA machine.
        input_ids=text_encoding["input_ids"].to(device),
        attention_mask=text_encoding["attention_mask"].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )

    preds = [
        tokenizer.decode(
            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)


# if __name__ == "__main__":
#     text = input('Entrez votre phrase à résumer : ')
#     print('summary:', inferenceAPI(text))