"""
Predict the summary of a given input text with the LSTM encoder-decoder model.
"""

import functools
import pickle

import torch

from src import dataloader
from src.model import Decoder, Encoder, EncoderDecoderModel

# from transformers import AutoModel

# Load the token list once at import time and wrap it in a Vectoriser, which
# provides encode()/decode() and idx_to_token.
# NOTE: pickle.load can execute arbitrary code; only use a trusted vocab.pkl.
with open("model/vocab.pkl", "rb") as vocab:
    words = pickle.load(vocab)

vectoriser = dataloader.Vectoriser(words)


@functools.lru_cache(maxsize=1)
def _load_model():
    """
    Build the encoder-decoder model once and reuse it for every prediction.

    Returns
    -------
    tuple
        ``(model, device)``: the model switched to evaluation mode, and the
        torch device (CUDA when available, otherwise CPU) it was moved to.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # +1 on the vocabulary size — presumably reserves an extra index
    # (e.g. padding/unknown); TODO confirm against src.dataloader.
    vocab_size = len(vectoriser.idx_to_token) + 1

    # Positional hyper-parameters (256, 512, 0.5) are passed unchanged;
    # presumably embedding size, hidden size and dropout — TODO confirm
    # against src.model.
    encoder = Encoder(vocab_size, 256, 512, 0.5, device)
    encoder.to(device)
    decoder = Decoder(vocab_size, 256, 512, 0.5, device)
    decoder.to(device)

    model = EncoderDecoderModel(encoder, decoder, vectoriser, device)

    # NOTE(review): no trained weights are loaded, so the model runs with its
    # random initialisation. Earlier revisions used
    #   model.load_state_dict(torch.load("model/model.pt", map_location=device))
    # or AutoModel.from_pretrained("EveSa/SummaryProject-LSTM");
    # confirm which weights source is current and restore it here.

    # Evaluation mode disables dropout layers, so repeated predictions on the
    # same text are not randomly perturbed.
    model.eval()
    return model, device


def inference_lstm(text: str) -> str:
    """
    Predict the summary for an input text.

    Parameters
    ----------
    text : str
        The text to summarize. It is tokenized by splitting on whitespace.

    Returns
    -------
    str
        The summary predicted for the input text.
    """
    model, device = _load_model()

    # Tokenize and vectorize the text, then move it to the model's device.
    tokens = text.split()
    source = vectoriser.encode(tokens).to(device)

    # Run the model without tracking gradients (inference only), then keep the
    # highest-scoring index along the last dimension (presumably the
    # vocabulary axis) at each position.
    with torch.no_grad():
        output = model(source).to(device)
        predicted = output.argmax(dim=-1)

    return vectoriser.decode(predicted)


if __name__ == "__main__":
    print(
        inference_lstm(
            "If you choose to use these attributes in logged messages, you need "
            "to exercise some care. In the above example, for instance, the "
            "Formatter has been set up with a format string which expects "
            "‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. "
            "If these are missing, the message will not be logged because a "
            "string formatting exception will occur. So in this case, you always "
            "need to pass the extra dictionary with these keys."
        )
    )