EveSa commited on
Commit
c8e0f9b
1 Parent(s): 3c03f61

refactoring de requirements.txt

Browse files
Files changed (1) hide show
  1. src/inference_lstm.py +59 -0
src/inference_lstm.py ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Allows to predict the summary for a given entry text
3
+ using LSTM model
4
+ """
5
+ import pickle
6
+
7
+ import torch
8
+
9
+ import dataloader
10
+ from model import Decoder, Encoder, EncoderDecoderModel
11
+ # from transformers import AutoModel
12
+
13
+ with open("model/vocab.pkl", "rb") as vocab:
14
+ words = pickle.load(vocab)
15
+ vectoriser = dataloader.Vectoriser(words)
16
+
17
+
18
+ def inferenceAPI(text: str) -> str:
19
+ """
20
+ Predict the summary for an input text
21
+ --------
22
+ Parameter
23
+ text: str
24
+ the text to sumarize
25
+ Return
26
+ str
27
+ The summary for the input text
28
+ """
29
+ text = text.split()
30
+ # On défini les paramètres d'entrée pour le modèle
31
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
32
+ encoder = Encoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
33
+ encoder.to(device)
34
+ decoder = Decoder(len(vectoriser.idx_to_token) + 1, 256, 512, 0.5, device)
35
+ decoder.to(device)
36
+
37
+ # On instancie le modèle
38
+ model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
39
+ # model = AutoModel.from_pretrained("EveSa/SummaryProject-LSTM")
40
+
41
+ # model.load_state_dict(torch.load("model/model.pt", map_location=device))
42
+ # model.eval()
43
+ # model.to(device)
44
+
45
+ # On vectorise le texte
46
+ source = vectoriser.encode(text)
47
+ source = source.to(device)
48
+
49
+ # On fait passer le texte dans le modèle
50
+ with torch.no_grad():
51
+ output = model(source).to(device)
52
+ output.to(device)
53
+ output = output.argmax(dim=-1)
54
+ return vectoriser.decode(output)
55
+
56
+
57
+ # if __name__ == "__main__":
58
+ # # inference()
59
+ # print(inferenceAPI("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))