"""Predict the summary of a given input text with the trained encoder-decoder model."""

import functools

import torch
from nltk import word_tokenize

import dataloader
from model import Decoder, Encoder, EncoderDecoderModel

# The datasets are loaded here only so the vectoriser (vocabulary) can be rebuilt.
# TODO: save the vocabulary ("words") to a file and load only the vectoriser,
# instead of re-reading both datasets at import time.
data1 = dataloader.Data("data/train_extract.jsonl")
data2 = dataloader.Data("data/dev_extract.jsonl")
train_dataset = data1.make_dataset()
dev_dataset = data2.make_dataset()
words = data1.get_words()
vectoriser = dataloader.Vectoriser(words)
word_counts = vectoriser.word_count

# Location of the trained weights.
_MODEL_PATH = "model/model.pt"
# Extra positional constructor arguments shared by Encoder and Decoder.
# NOTE(review): presumably (embedding dim, hidden dim, dropout) — confirm against
# model.Encoder / model.Decoder. They must match the sizes stored in the checkpoint.
_MODEL_ARGS = (256, 512, 0.5)


@functools.lru_cache(maxsize=1)
def _load_model():
    """Build the encoder-decoder model, load its weights once and return (model, device).

    The result is cached, so the checkpoint is read from disk only on the first call.
    """
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    # Vocabulary size (+1, as in the original configuration); must stay
    # consistent with the layer sizes saved in the checkpoint.
    vocab_size = len(vectoriser.idx_to_token) + 1
    encoder = Encoder(vocab_size, *_MODEL_ARGS, device).to(device)
    decoder = Decoder(vocab_size, *_MODEL_ARGS, device).to(device)

    model = EncoderDecoderModel(encoder, decoder, device)
    # NOTE(review): torch.load unpickles the file — only load trusted checkpoints.
    model.load_state_dict(torch.load(_MODEL_PATH, map_location=device))
    model.eval()
    model.to(device)
    return model, device


def inferenceAPI(text: str) -> str:
    """
    Predict the summary for an input text.

    The model is built and its weights are loaded on the first call only;
    later calls reuse the same model instance.

    --------
    Parameter
    text: str
        the text to summarize

    Return
    str
        The summary for the input text
    """
    tokens = word_tokenize(text)

    model, device = _load_model()

    # Vectorise the tokens and move them to the model's device.
    source = vectoriser.encode(tokens).to(device)

    # Run the model without tracking gradients (inference only).
    with torch.no_grad():
        output = model(source).to(device)

    return vectoriser.decode(output)


# Example usage (kept disabled):
# if __name__ == "__main__":
#     print(inferenceAPI(
#         "If you choose to use these attributes in logged messages, you need to exercise some care. "
#         "In the above example, for instance, the Formatter has been set up with a format string which "
#         "expects ‘clientip’ and ‘user’ in the attribute dictionary of the LogRecord. If these are missing, "
#         "the message will not be logged because a string formatting exception will occur. "
#         "So in this case, you always need to pass the extra dictionary with these keys."
#     ))