Spaces:
Runtime error
Runtime error
File size: 2,090 Bytes
c8e0f9b 1f86975 c8e0f9b 9cd8995 c8e0f9b |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 |
"""
Predict the summary of a given input text
using an LSTM encoder-decoder model.
"""
import functools
import os
import pickle
import warnings

import torch

from src import dataloader
from src.model import Decoder, Encoder, EncoderDecoderModel

# from transformers import AutoModel
# Vocabulary pickled alongside the model. It is read once, at import time, and
# wrapped in the module-level ``vectoriser`` that ``inference_lstm`` uses to
# encode its input and decode the model's output.
# NOTE(review): the path is relative to the current working directory, and
# pickle.load can execute arbitrary code -- only point it at a trusted,
# locally produced file.
_VOCAB_PATH = "model/vocab.pkl"

with open(_VOCAB_PATH, mode="rb") as vocab:
    words = pickle.load(vocab)

vectoriser = dataloader.Vectoriser(words)
# Trained weights for the encoder-decoder. Previously the loading step was
# commented out, so every prediction came from an untrained network.
# NOTE(review): torch.load unpickles the file -- only load trusted weights.
_WEIGHTS_PATH = "model/model.pt"


@functools.lru_cache(maxsize=1)
def _get_model():
    """
    Build the LSTM encoder-decoder once and reuse it for every call.

    Returns
    -------
    tuple
        ``(model, device)``: the model in evaluation mode and the
        ``torch.device`` it was placed on.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # NOTE(review): the +1 presumably reserves an index not stored in
    # idx_to_token (e.g. padding) -- confirm against src.dataloader.
    vocab_size = len(vectoriser.idx_to_token) + 1
    # NOTE(review): 256 / 512 / 0.5 look like embedding size, hidden size and
    # dropout -- confirm against src.model; they must match the values the
    # weights in _WEIGHTS_PATH were trained with.
    encoder = Encoder(vocab_size, 256, 512, 0.5, device)
    encoder.to(device)
    decoder = Decoder(vocab_size, 256, 512, 0.5, device)
    decoder.to(device)
    model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
    if os.path.exists(_WEIGHTS_PATH):
        model.load_state_dict(torch.load(_WEIGHTS_PATH, map_location=device))
    else:
        # Keep serving rather than crash, but make the problem visible.
        warnings.warn(
            f"{_WEIGHTS_PATH} not found: the LSTM model is running with "
            "untrained (random) weights.",
            RuntimeWarning,
        )
    model.to(device)
    # Switch layers such as dropout to inference behaviour so that repeated
    # calls on the same text give the same summary.
    model.eval()
    return model, device


def inference_lstm(text: str) -> str:
    """
    Predict the summary for an input text.

    The model is built (and its weights loaded) on the first call only;
    later calls reuse the same instance.

    Parameters
    ----------
    text : str
        The text to summarize. It is tokenized by splitting on whitespace.

    Returns
    -------
    str
        The summary predicted for the input text, or ``""`` when *text*
        contains no tokens.
    """
    tokens = text.split()
    if not tokens:
        # Nothing to summarize; do not feed an empty sequence to the model.
        return ""
    model, device = _get_model()
    # Vectorise the tokens and move the result to the model's device.
    source = vectoriser.encode(tokens).to(device)
    with torch.no_grad():
        output = model(source).to(device)
    # Greedy decoding: keep the highest-scoring entry along the last axis
    # (assumed to index the vocabulary -- confirm against src.model).
    predicted = output.argmax(dim=-1)
    return vectoriser.decode(predicted)
# if __name__ == "__main__":
#     # Example usage:
#     print(inference_lstm("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))
|