File size: 8,639 Bytes
e8f4897
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
from ..nn import Embedding
from ..nn import utils
import torch
import torch.nn as nn
import torch.nn.functional as F

class Sequence_Tagger(nn.Module):
    """Sequence tagger built from a ``BiRecurrentConv_Encoder`` and a ``Tagger_Decoder``.

    The encoder turns word / character / POS indices into contextual states
    and the decoder turns those states into per-token tag scores.  ``loss``
    and ``decode`` simply delegate to the decoder.
    """

    def __init__(self, word_dim, num_words, char_dim, num_chars, use_pos, use_char, pos_dim, num_pos,
                 num_filters, kernel_size, rnn_mode, hidden_size, num_layers, tag_space, num_tags,
                 embedd_word=None, embedd_char=None, embedd_pos=None,
                 p_in=0.33, p_out=0.33, p_rnn=(0.33, 0.33),
                 initializer=None):
        """Build the encoder and the decoder.

        All encoder-related arguments are forwarded unchanged to
        ``BiRecurrentConv_Encoder``; ``hidden_size``, ``tag_space``,
        ``num_tags``, ``p_out`` and ``initializer`` are forwarded to
        ``Tagger_Decoder``.
        """
        super().__init__()
        pretrained = dict(embedd_word=embedd_word, embedd_char=embedd_char, embedd_pos=embedd_pos)
        dropout_rates = dict(p_in=p_in, p_out=p_out, p_rnn=p_rnn)
        self.rnn_encoder = BiRecurrentConv_Encoder(
            word_dim, num_words, char_dim, num_chars, use_pos, use_char,
            pos_dim, num_pos, num_filters, kernel_size, rnn_mode, hidden_size, num_layers,
            initializer=initializer, **pretrained, **dropout_rates)
        self.sequence_tagger_decoder = Tagger_Decoder(hidden_size, tag_space, num_tags, p_out, initializer)

    def forward(self, input_word, input_char, input_pos, mask=None, length=None, hx=None):
        """Encode the inputs and score every token.

        Returns:
            A ``(scores, mask, length)`` tuple, where ``mask`` and ``length``
            are the values handed back by the encoder (which may differ from
            the ones passed in).
        """
        # The final hidden state returned by the encoder is not needed here.
        states, _, mask, length = self.rnn_encoder(input_word, input_char, input_pos, mask, length, hx)
        scores = self.sequence_tagger_decoder(states, mask)
        return scores, mask, length

    def loss(self, input, target, mask=None, length=None):
        """Return the decoder's loss for scores ``input`` against ``target``."""
        return self.sequence_tagger_decoder.loss(input, target, mask, length)

    def decode(self, input, mask=None, length=None, leading_symbolic=0):
        """Return the decoder's predictions for scores ``input``.

        ``length`` is accepted for interface compatibility but is not used.
        """
        return self.sequence_tagger_decoder.decode(input, mask, leading_symbolic)

class Tagger_Decoder(nn.Module):
    """Feed-forward tagging head: three linear layers over encoder states.

    The layer widths are ``2 * hidden_size -> tag_space -> tag_space // 2 ->
    num_tags``, with ELU and dropout after the first two layers.  The input
    width of ``2 * hidden_size`` matches the bidirectional RNN used by the
    encoder in this file.
    """

    def __init__(self, hidden_size, tag_space, num_tags, p_out, initializer):
        """Create the layers and initialise their parameters.

        Args:
            hidden_size: Per-direction hidden size of the encoder RNN.
            tag_space: Width of the first hidden layer.
            num_tags: Number of output scores per token.
            p_out: Dropout probability applied after each hidden layer.
            initializer: Optional callable applied in place to every
                parameter tensor with more than one dimension; when ``None``
                the default PyTorch initialisation is kept.
        """
        super(Tagger_Decoder, self).__init__()
        self.criterion_obj = nn.CrossEntropyLoss()
        self.tag_space = tag_space
        self.num_tags = num_tags
        self.p_out = p_out
        self.initializer = initializer
        self.dropout_out = nn.Dropout(p_out)
        self.out_dim = 2 * hidden_size
        self.fc_1 = nn.Linear(self.out_dim, tag_space)
        self.fc_2 = nn.Linear(tag_space, tag_space // 2)
        self.fc_3 = nn.Linear(tag_space // 2, num_tags)
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise parameters with ``self.initializer`` (no-op if it is ``None``).

        One-dimensional parameters (the biases) are zeroed; every other
        parameter is passed to the initializer.
        """
        if self.initializer is None:
            return
        for _, parameter in self.named_parameters():
            if parameter.dim() == 1:
                parameter.data.zero_()
            else:
                self.initializer(parameter.data)

    def forward(self, input, mask):
        """Map encoder states to tag scores.

        Args:
            input: Encoder output whose last dimension is ``2 * hidden_size``
                (the original comments give it as ``[batch_size, length, ...]``).
            mask: Accepted for interface compatibility; not used here.

        Returns:
            Tensor of raw (un-normalised) scores with last dimension ``num_tags``.
        """
        hidden = self.dropout_out(F.elu(self.fc_1(input)))
        hidden = self.dropout_out(F.elu(self.fc_2(hidden)))
        return self.fc_3(hidden)

    def loss(self, input, target, mask=None, length=None):
        """Cross-entropy between tag scores and gold tags.

        Args:
            input: Scores with last dimension ``num_tags``.
            target: Gold tag indices; indexed as ``[batch, time]``.
            mask: Accepted for interface compatibility; not used, so every
                position that survives the truncation below contributes to
                the loss.
            length: Optional tensor of sequence lengths.  When given,
                ``target`` is truncated along dimension 1 to ``length.max()``
                so that it lines up with scores computed on a batch that was
                trimmed to its longest sequence.

        Returns:
            Scalar loss tensor (mean over all scored positions).
        """
        if length is not None:
            max_len = int(length.max())
            if target.size(1) != max_len:
                target = target[:, :max_len]
        # reshape (rather than view) also accepts non-contiguous tensors.
        flat_scores = input.reshape(-1, self.num_tags)
        flat_target = target.reshape(-1)
        return self.criterion_obj(flat_scores, flat_target)

    def decode(self, input, mask=None, leading_symbolic=0):
        """Pick the highest-scoring tag index for every token.

        Args:
            input: Scores with the tag dimension last (3-D).
            mask: Optional mask that is broadcast over the tag dimension and
                multiplied into the scores, zeroing masked-out positions
                before the argmax.
            leading_symbolic: Number of score columns to drop before the
                argmax.  ``0`` (the default) keeps every column.

        Returns:
            Tensor of predicted indices with the tag dimension removed.
        """
        if mask is not None:
            input = input * mask.unsqueeze(2)
        # Slicing with ``:-0`` would yield an empty tag dimension and make
        # argmax fail, so only slice when there is something to drop.
        # NOTE(review): this drops the *last* ``leading_symbolic`` columns,
        # although the parameter name (and the earlier comment) suggest the
        # first ones were intended.  Kept as-is to preserve existing
        # predictions -- confirm against the label alphabet layout.
        if leading_symbolic != 0:
            input = input[:, :, :-leading_symbolic]
        return torch.argmax(input, -1)

class BiRecurrentConv_Encoder(nn.Module):
    """Bidirectional RNN encoder over word, character-CNN and POS features.

    Word embeddings are always used.  When ``use_char`` is set, character
    embeddings are run through a 1-D convolution, max-pooled over the
    character axis and squashed with ``tanh``.  When ``use_pos`` is set, POS
    embeddings are appended as well.  The concatenated features feed a
    batch-first bidirectional RNN / LSTM / GRU.
    """

    def __init__(self, word_dim, num_words, char_dim, num_chars, use_pos, use_char, pos_dim, num_pos, num_filters,
                 kernel_size, rnn_mode, hidden_size, num_layers, embedd_word=None, embedd_char=None, embedd_pos=None,
                 p_in=0.33, p_out=0.33, p_rnn=(0.33, 0.33), initializer=None):
        """Create the embeddings, the character CNN, the dropouts and the RNN.

        Args:
            rnn_mode: One of ``'RNN'``, ``'LSTM'`` or ``'GRU'``.
            p_in: Dropout probability for each embedded feature group.
            p_out: Dropout probability applied to the RNN output.
            p_rnn: Pair ``(input dropout, inter-layer dropout)`` for the RNN.
            initializer: Optional callable applied in place to non-embedding
                parameters with more than one dimension.

        Raises:
            ValueError: If ``rnn_mode`` is not a supported mode.
        """
        super().__init__()
        # Embedding tables / char CNN; the optional ones are None when disabled.
        self.word_embedd = Embedding(num_words, word_dim, init_embedding=embedd_word)
        self.char_embedd = (
            Embedding(num_chars, char_dim, init_embedding=embedd_char) if use_char else None
        )
        self.pos_embedd = (
            Embedding(num_pos, pos_dim, init_embedding=embedd_pos) if use_pos else None
        )
        self.conv1d = (
            nn.Conv1d(char_dim, num_filters, kernel_size, padding=kernel_size - 1) if use_char else None
        )
        # Dropout on the embedded inputs.
        self.dropout_in = nn.Dropout2d(p_in)
        # Dropout on the RNN output.
        self.dropout_out = nn.Dropout2d(p_out)
        # Dropout on the concatenated features right before the RNN.
        self.dropout_rnn_in = nn.Dropout(p_rnn[0])
        self.use_pos = use_pos
        self.use_char = use_char
        self.rnn_mode = rnn_mode
        # Width of the concatenated feature vector fed to the RNN.
        self.dim_enc = word_dim + (pos_dim if use_pos else 0) + (num_filters if use_char else 0)

        rnn_classes = {'RNN': nn.RNN, 'LSTM': nn.LSTM, 'GRU': nn.GRU}
        if rnn_mode not in rnn_classes:
            raise ValueError('Unknown RNN mode: %s' % rnn_mode)
        rnn_cls = rnn_classes[rnn_mode]
        self.rnn = rnn_cls(self.dim_enc, hidden_size, num_layers=num_layers, batch_first=True,
                           bidirectional=True, dropout=p_rnn[1])
        self.initializer = initializer
        self.reset_parameters()

    def reset_parameters(self):
        """Re-initialise every non-embedding parameter (no-op without an initializer).

        Parameters whose name contains ``'embedd'`` are left alone.  Of the
        rest, one-dimensional ones are zeroed and the others are passed to
        ``self.initializer``.
        """
        init_fn = self.initializer
        if init_fn is None:
            return

        for param_name, param in self.named_parameters():
            if 'embedd' in param_name:
                continue
            if param.dim() == 1:
                param.data.zero_()
            else:
                init_fn(param.data)

    def forward(self, input_word, input_char, input_pos, mask=None, length=None, hx=None):
        """Encode a batch of token sequences.

        Args:
            input_word: Word indices (``[batch_size, length]`` per the
                original comments).
            input_char: Character indices with one extra trailing character
                axis; only read when ``use_char`` is set.
            input_pos: POS indices; only read when ``use_pos`` is set.
            mask: Optional mask.  When ``length`` is not given, lengths are
                derived from it by summing over dimension 1.
            length: Optional sequence lengths.  When available (given or
                derived), the batch goes through ``utils.prepare_rnn_seq`` /
                ``utils.recover_rnn_seq`` around the RNN call.
            hx: Optional initial hidden state for the RNN.

        Returns:
            ``(output, hn, mask, length)`` -- the RNN output after dropout,
            the final hidden state, and the mask / lengths actually used.
        """
        # Lengths can be recovered from the mask, but not the other way
        # round, so callers must pass a mask whenever one is needed.
        if mask is not None and length is None:
            length = mask.data.sum(dim=1).long()

        # Embedded words, [batch_size, length, word_dim], with input dropout.
        pieces = [self.dropout_in(self.word_embedd(input_word))]

        if self.use_char:
            # [batch_size, length, char_length, char_dim]
            char_emb = self.char_embedd(input_char)
            n_batch, n_tokens, n_chars, n_dim = char_emb.size()
            # Fold batch and token axes together and move channels first:
            # [batch * length, char_dim, char_length]
            conv_in = char_emb.view(n_batch * n_tokens, n_chars, n_dim).transpose(1, 2)
            # Convolve, then max-pool over the character axis:
            # [batch * length, char_filters]
            pooled, _ = self.conv1d(conv_in).max(dim=2)
            # Back to [batch_size, length, char_filters]
            char_feat = torch.tanh(pooled).view(n_batch, n_tokens, -1)
            pieces.append(self.dropout_in(char_feat))

        if self.use_pos:
            # [batch_size, length, pos_dim]
            pieces.append(self.dropout_in(self.pos_embedd(input_pos)))

        # Concatenate the feature groups along the last axis.
        features = pieces[0] if len(pieces) == 1 else torch.cat(pieces, dim=2)
        features = self.dropout_rnn_in(features)

        if length is None:
            # No length information: run the RNN on the padded batch directly.
            self.rnn.flatten_parameters()
            output, hn = self.rnn(features, hx=hx)
        else:
            # Prepare the packed-sequence input, run the RNN, then undo the
            # reordering performed by the helper.
            packed, hx, rev_order, mask = utils.prepare_rnn_seq(features, length, hx=hx, masks=mask,
                                                                batch_first=True)
            self.rnn.flatten_parameters()
            packed_out, hn = self.rnn(packed, hx=hx)
            output, hn = utils.recover_rnn_seq(packed_out, rev_order, hx=hn, batch_first=True)

        # Dropout on the RNN output.
        return self.dropout_out(output), hn, mask, length