Spaces:
Runtime error
Runtime error
"""
Predict the summary of an input text with the LSTM encoder-decoder model.
"""
import pickle | |
import torch | |
from src import dataloader | |
from src.model import Decoder, Encoder, EncoderDecoderModel | |
# from transformers import AutoModel | |
# Load the pickled vocabulary (presumably shipped with the repository, so it
# is a trusted file to unpickle) and wrap it in the Vectoriser that
# inference_lstm uses to encode input tokens and decode predicted indices.
with open("model/vocab.pkl", mode="rb") as vocab:
    words = pickle.load(vocab)

vectoriser = dataloader.Vectoriser(words)
# Models already built for this process, keyed by (device, checkpoint path),
# so the encoder/decoder are constructed (and optionally loaded) only once.
# NOTE: a checkpoint file changed on disk after first use is not reloaded.
_MODEL_CACHE = {}


def _get_model(device, checkpoint=None):
    """Build (once per device/checkpoint) and return the model in eval mode."""
    key = (str(device), checkpoint)
    if key not in _MODEL_CACHE:
        # +1 on the vocabulary size, as in the original construction code.
        vocab_size = len(vectoriser.idx_to_token) + 1
        encoder = Encoder(vocab_size, 256, 512, 0.5, device)
        encoder.to(device)
        decoder = Decoder(vocab_size, 256, 512, 0.5, device)
        decoder.to(device)
        model = EncoderDecoderModel(encoder, decoder, vectoriser, device)
        if checkpoint is not None:
            model.load_state_dict(torch.load(checkpoint, map_location=device))
        # eval() disables train-only behaviour such as dropout (the 0.5 passed
        # above is presumably a dropout rate); without it, predictions are
        # random even under torch.no_grad().
        model.eval()
        _MODEL_CACHE[key] = model
    return _MODEL_CACHE[key]


def inference_lstm(text: str, checkpoint: "str | None" = None) -> str:
    """
    Predict the summary for an input text.

    Parameters
    ----------
    text: str
        The text to summarize; it is split into tokens on whitespace.
    checkpoint: str | None
        Optional path to a saved state dict (for example "model/model.pt")
        to load into the model. When None (the default), no weights are
        loaded and the model keeps its freshly initialised parameters,
        exactly as before this parameter existed.

    Returns
    -------
    str
        The summary for the input text.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    model = _get_model(device, checkpoint)
    # Vectorise the whitespace-separated tokens and move them to the device.
    source = vectoriser.encode(text.split()).to(device)
    with torch.no_grad():
        output = model(source)
    # Greedy decoding: keep the highest-scoring token index at each position.
    prediction = output.argmax(dim=-1)
    return vectoriser.decode(prediction)
# Example usage:
# if __name__ == "__main__":
#     print(inference_lstm("If you choose to use these attributes in logged messages, you need to exercise some care. In the above example, for instance, the Formatter has been set up with a format string which expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, the message will not be logged because a string formatting exception will occur. So in this case, you always need to pass the extra dictionary with these keys."))