File size: 7,401 Bytes
ad78747
 
 
 
 
 
 
 
 
 
8dba466
ad78747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
3805a61
ad78747
 
 
 
3805a61
ad78747
 
3805a61
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
ad78747
 
3805a61
 
 
 
 
 
 
 
 
 
 
ad78747
 
 
 
 
3805a61
 
ad78747
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
"""
    Defines the Encoder, Decoder and Sequence to Sequence models
    used in this projet
"""
import logging

import torch

logging.basicConfig(level=logging.DEBUG)


class Encoder(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embeddings_dim: int,
        hidden_size: int,
        dropout: int,
        device,
    ):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.device = device
        # On ajoute un mot supplémentaire au vocabulaire :
        # on s'en servira pour les mots inconnus
        self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
        self.embeddings.to(device)
        self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
        # Comme on va calculer la log-vraisemblance,
        # c'est le log-softmax qui nous intéresse
        self.dropout = torch.nn.Dropout(dropout)
        self.dropout.to(self.device)
        # Dropout

    def forward(self, inpt):
        inpt.to(self.device)
        emb = self.dropout(self.embeddings(inpt)).to(self.device)
        emb = emb.to(self.device)

        output, (hidden, cell) = self.hidden(emb)
        output.to(self.device)
        hidden = hidden.to(self.device)
        cell = cell.to(self.device)

        return hidden, cell


class Decoder(torch.nn.Module):
    def __init__(
        self,
        vocab_size: int,
        embeddings_dim: int,
        hidden_size: int,
        dropout: int,
        device,
    ):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.device = device
        # On ajoute un mot supplémentaire au vocabulaire :
        # on s'en servira pour les mots inconnus
        self.vocab_size = vocab_size
        self.embeddings = torch.nn.Embedding(vocab_size, embeddings_dim)
        self.hidden = torch.nn.LSTM(embeddings_dim, hidden_size, dropout=dropout)
        self.output = torch.nn.Linear(hidden_size, vocab_size)
        # Comme on va calculer la log-vraisemblance,
        # c'est le log-softmax qui nous intéresse
        self.dropout = torch.nn.Dropout(dropout)

    def forward(self, input, hidden, cell):
        input = input.unsqueeze(0)
        input = input.to(self.device)
        emb = self.dropout(self.embeddings(input)).to(self.device)
        emb = emb.to(self.device)

        output, (hidden, cell) = self.hidden(emb, (hidden, cell))
        output = output.to(self.device)
        out = self.output(output.squeeze(0)).to(self.device)
        return out, hidden, cell


class EncoderDecoderModel(torch.nn.Module):
    def __init__(self, encoder, decoder, vectoriser, device):
        # Une idiosyncrasie de torch, pour qu'iel puisse faire sa magie
        super().__init__()
        self.encoder = encoder
        self.decoder = decoder
        self.vectoriser = vectoriser
        self.device = device

    def forward(self, source, num_beams=3, summary_len=0.2):
        """
        :param source: tensor
            the input text
        :param num_beams: int
            the number of outputs to iterate on for beam_search
        :param summary_len: int
            length ratio of the summary compared to the text
        """
        # The ratio must be inferior to 1 to allow text compression
        assert summary_len < 1, f"number lesser than 1 expected, got {summary_len}"

        target_len = int(
            summary_len * source.shape[0]
        )  # Expected summary length (in words)
        target_vocab_size = self.decoder.vocab_size  # Word Embedding length

        # Output of the right format (expected summmary length x word embedding length)
        # filled with zeros. On each iteration, we will replace one of the row of this
        # matrix with the choosen word embedding
        outputs = torch.zeros(target_len, target_vocab_size)

        # put the tensors on the device (useless if CPU bus very useful in case of GPU)
        outputs.to(self.device)
        source.to(self.device)

        # last hidden state of the encoder is used as the initial hidden state of the decoder
        hidden, cell = self.encoder(source)  # Encode the input text
        input = self.vectoriser.encode(
            "<start>"
        )  # Encode the first word of the summary

        # put the tensors on the device
        hidden.to(self.device)
        cell.to(self.device)
        input.to(self.device)

        ### BEAM SEARCH ###
        # If you wonder, b stands for better
        values = None
        b_outputs = torch.zeros(target_len, target_vocab_size).to(self.device)
        b_outputs.to(self.device)

        for i in range(1, target_len):
            # On va déterminer autant de mot que la taille du texte souhaité
            # insert input token embedding, previous hidden and previous cell states
            # receive output tensor (predictions) and new hidden and cell states.

            # replace predictions in a tensor holding predictions for each token
            # logging.debug(f"output : {output}")

            ####### DÉBUT DU BEAM SEARCH ##########
            if values is None:
                # On calcule une première fois les premières probabilité de mot après <start>
                output, hidden, cell = self.decoder(input, hidden, cell)
                output.to(self.device)
                b_hidden = hidden
                b_cell = cell

                # On choisi les k meilleurs scores pour choisir la meilleure probabilité
                # sur deux itérations ensuite
                values, indices = output.topk(num_beams, sorted=True)

            else:
                # On instancie le dictionnaire qui contiendra les scores pour chaque possibilité
                scores = {}

                # Pour chacune des meilleures valeurs, on va calculer l'output
                for value, indice in zip(values, indices):
                    indice.to(self.device)

                    # On calcule l'output
                    b_output, b_hidden, b_cell = self.decoder(indice, b_hidden, b_cell)

                    # On empêche le modèle de se répéter d'un mot sur l'autre en mettant
                    # de force la probabilité du mot précédent à 0
                    b_output[indice] = torch.zeros(1)

                    # On choisit le meilleur résultat pour cette possibilité
                    highest_value = torch.log(b_output).max()

                    # On calcule le score des 2 itérations ensembles
                    score = highest_value * torch.log(value)
                    scores[score] = (b_output, b_hidden, b_cell)

                # On garde le meilleur score sur LES 2 ITÉRATIONS
                b_output, b_hidden, b_cell = scores.get(max(scores))

                # Et du coup on rempli la place de i-1 à la place de i
                b_outputs[i - 1] = b_output.to(self.device)

                # On instancies nos nouvelles valeurs pour la prochaine itération
                values, indices = b_output.topk(num_beams, sorted=True)

            ##################################

            # outputs[i] = output.to(self.device)
            # input = output.argmax(dim=-1).to(self.device)
            # input.to(self.device)

        # logging.debug(f"{vectoriser.decode(outputs.argmax(dim=-1))}")
        return b_outputs.to(self.device)