File size: 2,261 Bytes
ad78747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
"""
 Allows to predict the summary for a given entry text
"""
import functools

import torch
from nltk import word_tokenize

import dataloader
from model import Decoder, Encoder, EncoderDecoderModel

# Module-level setup, executed once when this module is imported.
# The datasets are loaded here only so that the vocabulary ("words") needed to
# build the Vectoriser can be recovered; inferenceAPI relies on `vectoriser`.
# We have to load the data to get the Vectoriser -> save "words" to a file and load it afterwards instead??
### TODO: CHANGE THIS SO THAT ONLY THE VECTORISER NEEDS TO BE LOADED
data1 = dataloader.Data("data/train_extract.jsonl")
data2 = dataloader.Data("data/dev_extract.jsonl")
# NOTE(review): train_dataset and dev_dataset are not used by inferenceAPI in
# this module; they may be needed for side effects of make_dataset() (e.g. before
# get_words()) or by other importers — confirm before removing.
train_dataset = data1.make_dataset()
dev_dataset = data2.make_dataset()
# The vocabulary comes from data1 (the train extract) only.
words = data1.get_words()

vectoriser = dataloader.Vectoriser(words)
# NOTE(review): word_counts is not referenced elsewhere in this module.
word_counts = vectoriser.word_count


@functools.lru_cache(maxsize=None)
def _load_model(device: torch.device) -> EncoderDecoderModel:
    """
    Build the encoder-decoder model and load its trained weights onto a device.

    Cached per device so that "model/model.pt" is read from disk once per
    process instead of on every inferenceAPI call. There is at most one entry
    per device actually used, so the cache is left unbounded.
    --------
    Parameter
        device: torch.device
            the device the model is built on and moved to
    Return
        EncoderDecoderModel
            the model with trained weights, in evaluation mode
    """
    # Vocabulary size given to Encoder/Decoder (idx_to_token size + 1), kept
    # exactly as before; presumably the +1 reserves an extra index — TODO confirm.
    vocab_size = len(vectoriser.idx_to_token) + 1
    # NOTE(review): 256 / 512 / 0.5 must agree with the hyper-parameters stored in
    # the checkpoint, presumably those used at training time.
    encoder = Encoder(vocab_size, 256, 512, 0.5, device).to(device)
    decoder = Decoder(vocab_size, 256, 512, 0.5, device).to(device)

    model = EncoderDecoderModel(encoder, decoder, device)
    model.load_state_dict(torch.load("model/model.pt", map_location=device))
    model.eval()
    model.to(device)
    return model


def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text

    The model is built and its weights loaded on the first call for a given
    device; subsequent calls reuse it.
    --------
    Parameter
        text: str
            the text to summarize
    Return
        str
            The summary for the input text
    """
    # Keep the tokenized form in its own variable rather than rebinding `text`.
    tokens = word_tokenize(text)
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _load_model(device)

    # Vectorise the tokens and move them to the model's device.
    source = vectoriser.encode(tokens).to(device)

    # Inference only: no gradient tracking needed.
    with torch.no_grad():
        output = model(source).to(device)
    return vectoriser.decode(output)


# if __name__ == "__main__":
#     # inference()
#     print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))