"""
 Allows to predict the summary for a given entry text
"""
import re
import string
import os
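# Cache downloaded Hugging Face models inside the project directory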
os.environ['TRANSFORMERS_CACHE'] = './.cache'

import torch
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer


def clean_text(texts: str) -> str:
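    """Lowercase the text, strip punctuation, and replace newlines with spaces.

    For example, "Hello, World!" becomes "hello world".
    """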
    texts = texts.lower()
    texts = texts.translate(str.maketrans("", "", string.punctuation))
    texts = re.sub(r"\n", " ", texts)
    return texts


def inference_t5(text: str) -> str:
    """
    Predict the summary for an input text
    --------
    Parameter
        text: str
            the text to sumarize
    Return
        str
            The summary for the input text
    """

    # Prepare the model input
    text = clean_text(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Assumes the access token is provided via the HF_TOKEN environment
    # variable rather than hardcoded in the source
    hf_token = os.environ.get("HF_TOKEN")

    # Load the tokenizer and model from the Hugging Face Hub
    tokenizer = AutoTokenizer.from_pretrained("Linggg/t5_summary", use_auth_token=hf_token)
    model = (
        AutoModelForSeq2SeqLM
        .from_pretrained("Linggg/t5_summary", use_auth_token=hf_token)
        .to(device)
    )

    text_encoding = tokenizer(
        text,
        max_length=1024,
        padding="max_length",
        truncation=True,
        return_attention_mask=True,
        add_special_tokens=True,
        return_tensors="pt",
    )
    # Move the encoded inputs to the same device as the model before generating
    generated_ids = model.generate(
        input_ids=text_encoding["input_ids"].to(device),
        attention_mask=text_encoding["attention_mask"].to(device),
        max_length=128,
        num_beams=8,
        length_penalty=0.8,
        early_stopping=True,
    )

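    # Decode the generated token ids back into text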
    preds = [
        tokenizer.decode(
            gen_id, skip_special_tokens=True, clean_up_tokenization_spaces=True
        )
        for gen_id in generated_ids
    ]
    return "".join(preds)


if __name__ == "__main__":
    text = input("Enter the text to summarize: ")
    print("summary:", inference_t5(text))