SummaryProject / src /model.py
EveSa's picture
refactoring de requirements.txt
3c03f61
raw
history blame
No virus
7.54 kB
"""
Defines the Encoder, Decoder and Sequence to Sequence models
used in this projet
"""
import logging
import torch
logging.basicConfig(level=logging.DEBUG)
class Encoder(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embeddings_dim: int,
hidden_size: int,
dropout: int,
device,
):
# Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
super().__init__()
self.device = device
# On ajoute un mot supplémentaire au vocabulaire :
# on s'en servira pour les mots inconnus
self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
self.embeddings.to(device)
self.hidden = torch.nn.LSTM(
embeddings_dim, hidden_size, dropout=dropout)
# Comme on va calculer la log-vraisemblance,
# c'est le log-softmax qui nous intéresse
self.dropout = torch.nn.Dropout(dropout)
self.dropout.to(self.device)
# Dropout
def forward(self, inpt):
inpt.to(self.device)
emb = self.dropout(self.embeddings(inpt)).to(self.device)
emb = emb.to(self.device)
output, (hidden, cell) = self.hidden(emb)
output.to(self.device)
hidden = hidden.to(self.device)
cell = cell.to(self.device)
return hidden, cell
class Decoder(torch.nn.Module):
def __init__(
self,
vocab_size: int,
embeddings_dim: int,
hidden_size: int,
dropout: int,
device,
):
# Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
super().__init__()
self.device = device
# On ajoute un mot supplémentaire au vocabulaire :
# on s'en servira pour les mots inconnus
self.vocab_size = vocab_size
self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
self.hidden = torch.nn.LSTM(
embeddings_dim, hidden_size, dropout=dropout)
self.output = torch.nn.Linear(hidden_size, vocab_size)
# Comme on va calculer la log-vraisemblance,
# c'est le log-softmax qui nous intéresse
self.dropout = torch.nn.Dropout(dropout)
def forward(self, input, hidden, cell):
input = input.unsqueeze(0)
input = input.to(self.device)
emb = self.dropout(self.embeddings(input)).to(self.device)
emb = emb.to(self.device)
output, (hidden, cell) = self.hidden(emb, (hidden, cell))
output = output.to(self.device)
out = self.output(output.squeeze(0)).to(self.device)
return out, hidden, cell
class EncoderDecoderModel(torch.nn.Module):
def __init__(self, encoder, decoder, vectoriser, device):
# Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
super().__init__()
self.encoder = encoder
self.decoder = decoder
self.vectoriser = vectoriser
self.device = device
def forward(self, source, num_beams=3, summary_len=0.2):
"""
:param source: tensor
the input text
:param num_beams: int
the number of outputs to iterate on for beam_search
:param summary_len: int
length ratio of the summary compared to the text
"""
# The ratio must be inferior to 1 to allow text compression
assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"
# Expected summary length (in words)
target_len = int(summary_len * source.shape[0])
# Word Embedding length
target_vocab_size = self.decoder.vocab_size
# Output of the right format (expected summmary length x word
# embedding length) filled with zeros. On each iteration, we
# will replace one of the row of this matrix with the choosen
# word embedding
outputs = torch.zeros(target_len, target_vocab_size)
# put the tensors on the device (useless if CPU bus very useful in
# case of GPU)
outputs.to(self.device)
source.to(self.device)
# last hidden state of the encoder is used
# as the initial hidden state of the decoder
# Encode the input text
hidden, cell = self.encoder(source)
# Encode the first word of the summary
input = self.vectoriser.encode("<start>")
# put the tensors on the device
hidden.to(self.device)
cell.to(self.device)
input.to(self.device)
# BEAM SEARCH #
# If you wonder, b stands for better
values = None
b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
b_outputs.to(self.device)
for i in range(1, target_len):
# On va déterminer autant de mot que la taille du texte souhaité
# insert input token embedding, previous hidden and previous cell states
# receive output tensor (predictions) and new hidden and cell
# states.
# replace predictions in a tensor holding predictions for each token
# logging.debug(f"output : {output}")
####### DÉBUT DU BEAM SEARCH ##########
if values is None:
# On calcule une première fois les premières probabilité de mot
# après <start>
output, hidden, cell = self.decoder(input, hidden, cell)
output.to(self.device)
b_hidden = hidden
b_cell = cell
# On choisi les k meilleurs scores pour choisir la meilleure probabilité
# sur deux itérations ensuite
values, indices = output.topk(num_beams, sorted=True)
else:
# On instancie le dictionnaire qui contiendra les scores pour
# chaque possibilité
scores = {}
# Pour chacune des meilleures valeurs, on va calculer l'output
for value, indice in zip(values, indices):
indice.to(self.device)
# On calcule l'output
b_output, b_hidden, b_cell = self.decoder(
indice, b_hidden, b_cell)
# On empêche le modèle de se répéter d'un mot sur l'autre en mettant
# de force la probabilité du mot précédent à 0
b_output[indice] = torch.zeros(1)
# On choisit le meilleur résultat pour cette possibilité
highest_value = torch.log(b_output).max()
# On calcule le score des 2 itérations ensembles
score = highest_value * torch.log(value)
scores[score] = (b_output, b_hidden, b_cell)
# On garde le meilleur score sur LES 2 ITÉRATIONS
b_output, b_hidden, b_cell = scores.get(max(scores))
# Et du coup on rempli la place de i-1 à la place de i
b_outputs[i - 1] = b_output.to(self.device)
# On instancies nos nouvelles valeurs pour la prochaine
# itération
values, indices = b_output.topk(num_beams, sorted=True)
##################################
# outputs[i] = output.to(self.device)
# input = output.argmax(dim=-1).to(self.device)
# input.to(self.device)
# logging.debug(f"{vectoriser.decode(outputs.argmax(dim=-1))}")
return b_outputs.to(self.device)