omarmomen committed on
Commit
a1cab64
1 Parent(s): 8c4e1ea
Files changed (3)
  1. config.json +32 -0
  2. pytorch_model.bin +3 -0
  3. structformer_in_parser.py +976 -0
config.json ADDED
@@ -0,0 +1,32 @@
+ {
+   "architectures": [
+     "StructFormer_In_ParserModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "structformer_in_parser.StructFormer_In_ParserConfig",
+     "AutoModelForMaskedLM": "structformer_in_parser.StructFormer_In_ParserModel"
+   },
+   "conv_size": 9,
+   "dropatt": 0.1,
+   "dropout": 0.1,
+   "front_layers": 4,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 512,
+   "initializer_range": 0.02,
+   "model_type": "structformer_in_parser",
+   "n_parser_layers": 3,
+   "nhead": 8,
+   "nlayers": 8,
+   "ntokens": 16000,
+   "pad": 1,
+   "pos_emb": true,
+   "rear_layers": 4,
+   "relations": [
+     "head",
+     "child"
+   ],
+   "relative_bias": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.18.0",
+   "weight_act": "softmax"
+ }
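For reference, a minimal loading sketch (not part of this commit): the auto_map above routes AutoConfig and AutoModelForMaskedLM to the classes defined in structformer_in_parser.py, so the checkpoint can be loaded with trust_remote_code=True. The repo id below is a placeholder, not taken from the commit.

from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "omarmomen/structformer_in_parser"  # placeholder repo id; adjust to the actual model repo

# trust_remote_code lets transformers import structformer_in_parser.py from the repo
# and resolve the auto_map entries declared in config.json.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
print(type(model).__name__)  # expected: StructFormer_In_ParserModel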
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:bea2e965359c24a325c41425144676e12481794677a98cc5ad661afacee873bf
+ size 166262995
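The file committed to git is only the git-lfs pointer shown above; the actual 166 MB weight file lives in LFS storage. A small sketch (assuming the huggingface_hub client and the same placeholder repo id as above) of fetching and inspecting the real weights:

import torch
from huggingface_hub import hf_hub_download

repo_id = "omarmomen/structformer_in_parser"  # placeholder repo id

# hf_hub_download resolves the LFS pointer and returns a local path to the real file.
weights_path = hf_hub_download(repo_id=repo_id, filename="pytorch_model.bin")
state_dict = torch.load(weights_path, map_location="cpu")
print(len(state_dict), "parameter tensors")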
structformer_in_parser.py ADDED
@@ -0,0 +1,976 @@
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
+ ##########################################
+ def _get_activation_fn(activation):
+   """Get specified activation function."""
+   if activation == "relu":
+     return nn.ReLU()
+   elif activation == "gelu":
+     return nn.GELU()
+   elif activation == "leakyrelu":
+     return nn.LeakyReLU()
+
+   raise RuntimeError(
+       "activation should be relu/gelu, not {}".format(activation))
+
+
+ class Conv1d(nn.Module):
+   """1D convolution layer."""
+
+   def __init__(self, hidden_size, kernel_size, dilation=1):
+     """Initialization.
+
+     Args:
+       hidden_size: dimension of input embeddings
+       kernel_size: convolution kernel size
+       dilation: the spacing between the kernel points
+     """
+     super(Conv1d, self).__init__()
+
+     if kernel_size % 2 == 0:
+       padding = (kernel_size // 2) * dilation
+       self.shift = True
+     else:
+       padding = ((kernel_size - 1) // 2) * dilation
+       self.shift = False
+     self.conv = nn.Conv1d(
+         hidden_size,
+         hidden_size,
+         kernel_size,
+         padding=padding,
+         dilation=dilation)
+
+   def forward(self, x):
+     """Compute convolution.
+
+     Args:
+       x: input embeddings
+     Returns:
+       conv_output: convolution results
+     """
+
+     if self.shift:
+       return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
+     else:
+       return self.conv(x.transpose(1, 2)).transpose(1, 2)
+
+
+ class MultiheadAttention(nn.Module):
+   """Multi-head self-attention layer."""
+
+   def __init__(self,
+                embed_dim,
+                num_heads,
+                dropout=0.,
+                bias=True,
+                v_proj=True,
+                out_proj=True,
+                relative_bias=True):
+     """Initialization.
+
+     Args:
+       embed_dim: dimension of input embeddings
+       num_heads: number of self-attention heads
+       dropout: dropout rate
+       bias: bool, indicate whether include bias for linear transformations
+       v_proj: bool, indicate whether project inputs to new values
+       out_proj: bool, indicate whether project outputs to new values
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+     """
+
+     super(MultiheadAttention, self).__init__()
+     self.embed_dim = embed_dim
+
+     self.num_heads = num_heads
+     self.drop = nn.Dropout(dropout)
+     self.head_dim = embed_dim // num_heads
+     assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
+                                                          "divisible by "
+                                                          "num_heads")
+
+     self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     if v_proj:
+       self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     else:
+       self.v_proj = nn.Identity()
+
+     if out_proj:
+       self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     else:
+       self.out_proj = nn.Identity()
+
+     if relative_bias:
+       self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
+     else:
+       self.relative_bias = None
+
+     self._reset_parameters()
+
+   def _reset_parameters(self):
+     """Initialize attention parameters."""
+
+     init.xavier_uniform_(self.q_proj.weight)
+     init.constant_(self.q_proj.bias, 0.)
+
+     init.xavier_uniform_(self.k_proj.weight)
+     init.constant_(self.k_proj.bias, 0.)
+
+     if isinstance(self.v_proj, nn.Linear):
+       init.xavier_uniform_(self.v_proj.weight)
+       init.constant_(self.v_proj.bias, 0.)
+
+     if isinstance(self.out_proj, nn.Linear):
+       init.xavier_uniform_(self.out_proj.weight)
+       init.constant_(self.out_proj.bias, 0.)
+
+   def forward(self, query, key_padding_mask=None, attn_mask=None):
+     """Compute multi-head self-attention.
+
+     Args:
+       query: input embeddings
+       key_padding_mask: 3D mask that prevents attention to certain positions
+       attn_mask: 3D mask that rescale the attention weight at each position
+     Returns:
+       attn_output: self-attention output
+     """
+
+     length, bsz, embed_dim = query.size()
+     assert embed_dim == self.embed_dim
+
+     head_dim = embed_dim // self.num_heads
+     assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
+                                                     "divisible by num_heads")
+     scaling = float(head_dim)**-0.5
+
+     q = self.q_proj(query)
+     k = self.k_proj(query)
+     v = self.v_proj(query)
+
+     q = q * scaling
+
+     if attn_mask is not None:
+       assert list(attn_mask.size()) == [bsz * self.num_heads,
+                                         query.size(0), query.size(0)]
+
+     q = q.contiguous().view(length, bsz * self.num_heads,
+                             head_dim).transpose(0, 1)
+     k = k.contiguous().view(length, bsz * self.num_heads,
+                             head_dim).transpose(0, 1)
+     v = v.contiguous().view(length, bsz * self.num_heads,
+                             head_dim).transpose(0, 1)
+
+     attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+     assert list(
+         attn_output_weights.size()) == [bsz * self.num_heads, length, length]
+
+     if self.relative_bias is not None:
+       pos = torch.arange(length, device=query.device)
+       relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
+       relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
+                                                      -1)
+
+       relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
+       relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
+       relative_bias = torch.gather(relative_bias, 2, relative_pos)
+       attn_output_weights = attn_output_weights + relative_bias
+
+     if key_padding_mask is not None:
+       attn_output_weights = attn_output_weights + key_padding_mask
+
+     if attn_mask is None:
+       attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
+     else:
+       attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
+
+     attn_output_weights = self.drop(attn_output_weights)
+
+     attn_output = torch.bmm(attn_output_weights, v)
+
+     assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
+     attn_output = attn_output.transpose(0, 1).contiguous().view(
+         length, bsz, embed_dim)
+     attn_output = self.out_proj(attn_output)
+
+     return attn_output
+
+
+ class TransformerLayer(nn.Module):
+   """TransformerEncoderLayer is made up of self-attn and feedforward network."""
+
+   def __init__(self,
+                d_model,
+                nhead,
+                dim_feedforward=2048,
+                dropout=0.1,
+                dropatt=0.1,
+                activation="leakyrelu",
+                relative_bias=True):
+     """Initialization.
+
+     Args:
+       d_model: dimension of inputs
+       nhead: number of self-attention heads
+       dim_feedforward: dimension of hidden layer in feedforward layer
+       dropout: dropout rate
+       dropatt: drop attention rate
+       activation: activation function
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+     """
+
+     super(TransformerLayer, self).__init__()
+     self.self_attn = MultiheadAttention(
+         d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
+     # Implementation of Feedforward model
+     self.feedforward = nn.Sequential(
+         nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
+         _get_activation_fn(activation), nn.Dropout(dropout),
+         nn.Linear(dim_feedforward, d_model))
+
+     self.norm = nn.LayerNorm(d_model)
+     self.dropout1 = nn.Dropout(dropout)
+     self.dropout2 = nn.Dropout(dropout)
+
+     self.nhead = nhead
+
+   def forward(self, src, attn_mask=None, key_padding_mask=None):
+     """Pass the input through the encoder layer.
+
+     Args:
+       src: the sequence to the encoder layer (required).
+       attn_mask: the mask for the src sequence (optional).
+       key_padding_mask: the mask for the src keys per batch (optional).
+     Returns:
+       src3: the output of transformer layer, share the same shape as src.
+     """
+     src2 = self.self_attn(
+         self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+     src2 = src + self.dropout1(src2)
+     src3 = self.feedforward(src2)
+     src3 = src2 + self.dropout2(src3)
+
+     return src3
+
+ ##########################################
+ def cumprod(x, reverse=False, exclusive=False):
+   """cumulative product."""
+   if reverse:
+     x = x.flip([-1])
+
+   if exclusive:
+     x = F.pad(x[:, :, :-1], (1, 0), value=1)
+
+   cx = x.cumprod(-1)
+
+   if reverse:
+     cx = cx.flip([-1])
+   return cx
+
+
+ def cumsum(x, reverse=False, exclusive=False):
+   """cumulative sum."""
+   bsz, _, length = x.size()
+   device = x.device
+   if reverse:
+     if exclusive:
+       w = torch.ones([bsz, length, length], device=device).tril(-1)
+     else:
+       w = torch.ones([bsz, length, length], device=device).tril(0)
+     cx = torch.bmm(x, w)
+   else:
+     if exclusive:
+       w = torch.ones([bsz, length, length], device=device).triu(1)
+     else:
+       w = torch.ones([bsz, length, length], device=device).triu(0)
+     cx = torch.bmm(x, w)
+   return cx
+
+
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
+   """cumulative min."""
+   if reverse:
+     if exclusive:
+       x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
+     x = x.flip([-1]).cummin(-1)[0].flip([-1])
+   else:
+     if exclusive:
+       x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
+     x = x.cummin(-1)[0]
+   return x
+
+
+ class Transformer_Front(nn.Module):
+   """Transformer model."""
+
+   def __init__(self,
+                hidden_size,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=True,
+                pos_emb=False,
+                pad=0):
+     """Initialization.
+
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+       pos_emb: bool, indicate whether use a learnable positional embedding
+       pad: pad token index
+     """
+
+     super(Transformer_Front, self).__init__()
+
+     self.drop = nn.Dropout(dropout)
+
+     self.emb = nn.Embedding(ntokens, hidden_size)
+     if pos_emb:
+       self.pos_emb = nn.Embedding(500, hidden_size)
+
+     self.layers = nn.ModuleList([
+         TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                          dropatt=dropatt, relative_bias=relative_bias)
+         for _ in range(nlayers)])
+
+     self.norm = nn.LayerNorm(hidden_size)
+
+     self.init_weights()
+
+     self.nlayers = nlayers
+     self.nhead = nhead
+     self.ntokens = ntokens
+     self.hidden_size = hidden_size
+     self.pad = pad
+
+   def init_weights(self):
+     """Initialize token embedding and output bias."""
+     initrange = 0.1
+     self.emb.weight.data.uniform_(-initrange, initrange)
+     if hasattr(self, 'pos_emb'):
+       self.pos_emb.weight.data.uniform_(-initrange, initrange)
+
+
+   def visibility(self, x, device):
+     """Mask pad tokens."""
+     visibility = (x != self.pad).float()
+     visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
+     visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
+     return visibility.log()
+
+   def encode(self, x, pos):
+     """Standard transformer encode process."""
+     h = self.emb(x)
+     if hasattr(self, 'pos_emb'):
+       h = h + self.pos_emb(pos)
+     h_list = []
+     visibility = self.visibility(x, x.device)
+
+     for i in range(self.nlayers):
+       h_list.append(h)
+       h = self.layers[i](
+           h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
+
+     output = h
+     h_array = torch.stack(h_list, dim=2)
+
+     return output, h_array
+
+   def forward(self, x, pos):
+     """Pass the input through the encoder layer.
+
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       output: probability distributions for missing tokens.
+       state_dict: parsing results and raw output
+     """
+
+     batch_size, length = x.size()
+
+     raw_output, _ = self.encode(x, pos)
+     raw_output = self.norm(raw_output)
+     raw_output = self.drop(raw_output)
+
+     return {'raw_output': raw_output}
+
+
+ class Transformer_Rear(nn.Module):
+   """Transformer model."""
+
+   def __init__(self,
+                hidden_size,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=True,
+                pos_emb=False,
+                pad=0):
+     """Initialization.
+
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+       pos_emb: bool, indicate whether use a learnable positional embedding
+       pad: pad token index
+     """
+
+     super(Transformer_Rear, self).__init__()
+
+     self.drop = nn.Dropout(dropout)
+
+     self.emb = nn.Embedding(ntokens, hidden_size)
+     if pos_emb:
+       self.pos_emb = nn.Embedding(500, hidden_size)
+
+     self.layers = nn.ModuleList([
+         TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                          dropatt=dropatt, relative_bias=relative_bias)
+         for _ in range(nlayers)])
+
+     self.norm = nn.LayerNorm(hidden_size)
+
+     self.output_layer = nn.Linear(hidden_size, ntokens)
+
+     self.init_weights()
+
+     self.nlayers = nlayers
+     self.nhead = nhead
+     self.ntokens = ntokens
+     self.hidden_size = hidden_size
+     self.pad = pad
+
+   def init_weights(self):
+     """Initialize token embedding and output bias."""
+     initrange = 0.1
+     self.emb.weight.data.uniform_(-initrange, initrange)
+     if hasattr(self, 'pos_emb'):
+       self.pos_emb.weight.data.uniform_(-initrange, initrange)
+     self.output_layer.bias.data.fill_(0)
+
+   def visibility(self, x, device):
+     """Mask pad tokens."""
+     visibility = (x != self.pad).float()
+     visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
+     visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
+     return visibility.log()
+
+   def encode(self, x, pos, att_mask, h):
+     """Structformer encoding process."""
+
+     visibility = self.visibility(x, x.device)
+
+     if hasattr(self, 'pos_emb'):
+       assert pos.max() < 500
+       h = h + self.pos_emb(pos)
+     for i in range(self.nlayers):
+       h = self.layers[i](
+           h.transpose(0, 1), attn_mask=att_mask[i],
+           key_padding_mask=visibility).transpose(0, 1)
+     return h
+
+   def forward(self, x, pos):
+     """Pass the input through the encoder layer.
+
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       output: probability distributions for missing tokens.
+       state_dict: parsing results and raw output
+     """
+
+     batch_size, length = x.size()
+
+     raw_output, _ = self.encode(x, pos)
+     raw_output = self.norm(raw_output)
+     raw_output = self.drop(raw_output)
+
+     output = self.output_layer(raw_output)
+     return output.view(batch_size * length, -1), {'raw_output': raw_output,}
+
+
+ class StructFormer_In_Parser(nn.Module):
+   """StructFormer model."""
+
+   def __init__(self,
+                hidden_size,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=False,
+                pos_emb=False,
+                front_layers=2,
+                rear_layers=6,
+                pad=0,
+                n_parser_layers=4,
+                conv_size=9,
+                relations=('head', 'child'),
+                weight_act='softmax'):
+     """Initialization.
+
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+       pos_emb: bool, indicate whether use a learnable positional embedding
+       pad: pad token index
+       n_parser_layers: number of parsing layers
+       conv_size: convolution kernel size for parser
+       relations: relations that are used to compute self attention
+       weight_act: relations distribution activation function
+     """
+
+     super(StructFormer_In_Parser, self).__init__()
+
+     self.transformer_front = Transformer_Front(
+         hidden_size,
+         nlayers=front_layers,
+         ntokens=ntokens,
+         nhead=nhead,
+         dropout=dropout,
+         dropatt=dropatt,
+         relative_bias=relative_bias,
+         pos_emb=pos_emb,
+         pad=pad
+     )
+
+     self.transformer_rear = Transformer_Rear(
+         hidden_size,
+         nlayers=rear_layers,
+         ntokens=ntokens,
+         nhead=nhead,
+         dropout=dropout,
+         dropatt=dropatt,
+         relative_bias=relative_bias,
+         pos_emb=pos_emb,
+         pad=pad
+     )
+     self.transformer_rear.emb.weight = self.transformer_front.emb.weight
+     self.transformer_rear.output_layer.weight = self.transformer_front.emb.weight
+     if pos_emb:
+       self.transformer_rear.pos_emb.weight = self.transformer_front.pos_emb.weight
+
+     self.parser_layers = nn.ModuleList([
+         nn.Sequential(Conv1d(hidden_size, conv_size),
+                       nn.LayerNorm(hidden_size, elementwise_affine=False),
+                       nn.Tanh()) for i in range(n_parser_layers)])
+
+     self.distance_ff = nn.Sequential(
+         Conv1d(hidden_size, 2),
+         nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+         nn.Linear(hidden_size, 1))
+
+     self.height_ff = nn.Sequential(
+         nn.Linear(hidden_size, hidden_size),
+         nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+         nn.Linear(hidden_size, 1))
+
+     n_rel = len(relations)
+     self._rel_weight = nn.Parameter(torch.zeros((self.transformer_rear.nlayers, nhead, n_rel)))
+     self._rel_weight.data.normal_(0, 0.1)
+
+     self._scaler = nn.Parameter(torch.zeros(2))
+
+     self.n_parse_layers = n_parser_layers
+     self.weight_act = weight_act
+     self.relations = relations
+
+   @property
+   def scaler(self):
+     return self._scaler.exp()
+
+   @property
+   def rel_weight(self):
+     if self.weight_act == 'sigmoid':
+       return torch.sigmoid(self._rel_weight)
+     elif self.weight_act == 'softmax':
+       return torch.softmax(self._rel_weight, dim=-1)
+
+   def parse(self, x, h):
+     """Parse input sentence.
+
+     Args:
+       x: input tokens (required).
+       h: hidden states from the front transformer (required).
+     Returns:
+       distance: syntactic distance
+       height: syntactic height
+     """
+
+     mask = (x != self.transformer_rear.pad)
+     mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
+
+     for i in range(self.n_parse_layers):
+       h = h.masked_fill(~mask[:, :, None], 0)
+       h = self.parser_layers[i](h)
+
+     height = self.height_ff(h).squeeze(-1)
+     height.masked_fill_(~mask, -1e9)
+
+     distance = self.distance_ff(h).squeeze(-1)
+     distance.masked_fill_(~mask_shifted, 1e9)
+
+     # Calibrating the distance and height to the same level
+     length = distance.size(1)
+     height_max = height[:, None, :].expand(-1, length, -1)
+     height_max = torch.cummax(
+         height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
+         dim=-1)[0].triu(0)
+
+     margin_left = torch.relu(
+         F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
+     margin_right = torch.relu(distance[:, None, :] - height_max)
+     margin = torch.where(margin_left > margin_right, margin_right,
+                          margin_left).triu(0)
+
+     margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
+     margin.masked_fill_(~margin_mask, 0)
+     margin = margin.max()
+
+     distance = distance - margin
+
+     return distance, height
+
+   def compute_block(self, distance, height):
+     """Compute constituents from distance and height."""
+
+     beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
+
+     gamma = torch.sigmoid(-beta_logits)
+     ones = torch.ones_like(gamma)
+
+     block_mask_left = cummin(
+         gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
+     block_mask_left = block_mask_left - F.pad(
+         block_mask_left[:, :, :-1], (1, 0), value=0)
+     block_mask_left.tril_(0)
+
+     block_mask_right = cummin(
+         gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
+     block_mask_right = block_mask_right - F.pad(
+         block_mask_right[:, :, 1:], (0, 1), value=0)
+     block_mask_right.triu_(0)
+
+     block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
+     block = cumsum(block_mask_left).tril(0) + cumsum(
+         block_mask_right, reverse=True).triu(1)
+
+     return block_p, block
+
+   def compute_head(self, height):
+     """Estimate head for each constituent."""
+
+     _, length = height.size()
+     head_logits = height * self.scaler[1]
+     index = torch.arange(length, device=height.device)
+
+     mask = (index[:, None, None] <= index[None, None, :]) * (
+         index[None, None, :] <= index[None, :, None])
+     head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
+     head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
+
+     head_p = torch.softmax(head_logits, dim=-1)
+
+     return head_p
+
+   def generate_mask(self, x, distance, height):
+     """Compute head and cibling distribution for each token."""
+
+     bsz, length = x.size()
+
+     eye = torch.eye(length, device=x.device, dtype=torch.bool)
+     eye = eye[None, :, :].expand((bsz, -1, -1))
+
+     block_p, block = self.compute_block(distance, height)
+     head_p = self.compute_head(height)
+     head = torch.einsum('blij,bijh->blh', block_p, head_p)
+     head = head.masked_fill(eye, 0)
+     child = head.transpose(1, 2)
+     cibling = torch.bmm(head, child).masked_fill(eye, 0)
+
+     rel_list = []
+     if 'head' in self.relations:
+       rel_list.append(head)
+     if 'child' in self.relations:
+       rel_list.append(child)
+     if 'cibling' in self.relations:
+       rel_list.append(cibling)
+
+     rel = torch.stack(rel_list, dim=1)
+
+     rel_weight = self.rel_weight
+
+     dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
+     att_mask = dep.reshape(self.transformer_rear.nlayers, bsz * self.transformer_rear.nhead, length, length)
+
+     return att_mask, cibling, head, block
+
+   def forward(self, x, pos):
+     """Pass the input through the encoder layer.
+
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       output: probability distributions for missing tokens.
+       state_dict: parsing results and raw output
+     """
+
+     batch_size, length = x.size()
+
+     raw_output_1, _ = self.transformer_front.encode(x, pos)
+     raw_output_1 = self.transformer_front.norm(raw_output_1)
+     raw_output_1 = self.transformer_front.drop(raw_output_1)
+
+     distance, height = self.parse(x, raw_output_1)
+     att_mask, cibling, head, block = self.generate_mask(x, distance, height)
+
+     raw_output_2 = self.transformer_rear.encode(x, pos, att_mask, raw_output_1)
+     raw_output_2 = self.transformer_rear.norm(raw_output_2)
+     raw_output_2 = self.transformer_rear.drop(raw_output_2)
+
+     output = self.transformer_rear.output_layer(raw_output_2)
+
+     return output.view(batch_size * length, -1), \
+         {'raw_output': raw_output_2, 'distance': distance, 'height': height,
+          'cibling': cibling, 'head': head, 'block': block}
+
+
+
+ ##########################################
+ # Classification Head For BabyLM Evaluation Tasks
+ ##########################################
+ class ClassificationHead(nn.Module):
+   """Head for sentence-level classification tasks."""
+   def __init__(self, config):
+     super(ClassificationHead, self).__init__()
+     self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+     self.dropout = nn.Dropout(config.hidden_dropout_prob)
+     self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+   def forward(self, features, **kwargs):
+     x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+     x = self.dropout(x)
+     x = self.dense(x)
+     x = torch.tanh(x)
+     x = self.dropout(x)
+     x = self.out_proj(x)
+     return x
+
+ ##########################################
+ # HuggingFace Config
+ ##########################################
+ class StructFormer_In_ParserConfig(PretrainedConfig):
+   model_type = "structformer_in_parser"
+
+   def __init__(
+       self,
+       hidden_size=512,
+       nlayers=8,
+       ntokens=10_000,
+       nhead=8,
+       dropout=0.1,
+       dropatt=0.1,
+       relative_bias=False,
+       pos_emb=False,
+       pad=0,
+       n_parser_layers=4,
+       front_layers=2,
+       rear_layers=6,
+       conv_size=9,
+       relations=('head', 'child'),
+       weight_act='softmax',
+       num_labels=1,
+       hidden_dropout_prob=0.1,
+       initializer_range=0.02,
+       **kwargs,
+   ):
+     self.hidden_size = hidden_size
+     self.nlayers = nlayers
+     self.ntokens = ntokens
+     self.nhead = nhead
+     self.dropout = dropout
+     self.dropatt = dropatt
+     self.relative_bias = relative_bias
+     self.pos_emb = pos_emb
+     self.pad = pad
+     self.n_parser_layers = n_parser_layers
+     self.front_layers = front_layers
+     self.rear_layers = rear_layers
+     self.conv_size = conv_size
+     self.relations = relations
+     self.weight_act = weight_act
+     self.num_labels = num_labels
+     self.hidden_dropout_prob = hidden_dropout_prob
+     self.initializer_range = initializer_range
+     super().__init__(**kwargs)
+
+
+ ##########################################
+ # HuggingFace Models
+ ##########################################
+ class StructFormer_In_ParserModel(PreTrainedModel):
+   config_class = StructFormer_In_ParserConfig
+
+   def __init__(self, config):
+     super().__init__(config)
+     self.model = StructFormer_In_Parser(
+         hidden_size=config.hidden_size,
+         nlayers=config.nlayers,
+         ntokens=config.ntokens,
+         nhead=config.nhead,
+         dropout=config.dropout,
+         dropatt=config.dropatt,
+         relative_bias=config.relative_bias,
+         pos_emb=config.pos_emb,
+         pad=config.pad,
+         n_parser_layers=config.n_parser_layers,
+         front_layers=config.front_layers,
+         rear_layers=config.rear_layers,
+         conv_size=config.conv_size,
+         relations=config.relations,
+         weight_act=config.weight_act
+     )
+     self.config = config
+
+   def parse(self, input_ids, **kwargs):
+     x = input_ids
+     batch_size, length = x.size()
+     pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
+
+     sf_output = self.model(x, pos)
+
+     return sf_output[1]
+
+   def forward(self, input_ids, labels=None, **kwargs):
+     x = input_ids
+     batch_size, length = x.size()
+     pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
+
+     sf_output = self.model(x, pos)
+
+     loss = None
+     if labels is not None:
+       loss_fct = nn.CrossEntropyLoss()
+       loss = loss_fct(sf_output[0], labels.reshape(-1))
+
+     return MaskedLMOutput(
+         loss=loss,  # shape: 1
+         logits=sf_output[0].view(batch_size, length, -1),  # shape: (batch_size, length, ntokens)
+         hidden_states=None,
+         attentions=None
+     )
+
+ class StructFormer_In_ParserModelForSequenceClassification(PreTrainedModel):
+   config_class = StructFormer_In_ParserConfig
+
+   def __init__(self, config):
+     super().__init__(config)
+     self.model = StructFormer_In_Parser(
+         hidden_size=config.hidden_size,
+         nlayers=config.nlayers,
+         ntokens=config.ntokens,
+         nhead=config.nhead,
+         dropout=config.dropout,
+         dropatt=config.dropatt,
+         relative_bias=config.relative_bias,
+         pos_emb=config.pos_emb,
+         pad=config.pad,
+         n_parser_layers=config.n_parser_layers,
+         front_layers=config.front_layers,
+         rear_layers=config.rear_layers,
+         conv_size=config.conv_size,
+         relations=config.relations,
+         weight_act=config.weight_act
+     )
+     self.config = config
+     self.num_labels = config.num_labels
+     self.model.classifier = ClassificationHead(config)
+
+   def _init_weights(self, module):
+     """Initialize the weights"""
+     if isinstance(module, nn.Linear):
+       # Slightly different from the TF version which uses truncated_normal for initialization
+       # cf https://github.com/pytorch/pytorch/pull/5617
+       module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+       if module.bias is not None:
+         module.bias.data.zero_()
+     elif isinstance(module, nn.Embedding):
+       module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+       if module.padding_idx is not None:
+         module.weight.data[module.padding_idx].zero_()
+     elif isinstance(module, nn.LayerNorm):
+       if module.bias is not None:
+         module.bias.data.zero_()
+       module.weight.data.fill_(1.0)
+
+   def forward(self, input_ids, labels=None, **kwargs):
+     x = input_ids
+     batch_size, length = x.size()
+     pos = kwargs['position_ids'] if 'position_ids' in kwargs.keys() else torch.arange(length, device=x.device).expand(batch_size, length)
+
+     sf_output = self.model(x, pos)
+
+     logits = self.model.classifier(sf_output[1]['raw_output'])
+     loss = None
+     if labels is not None:
+       if self.config.problem_type is None:
+         if self.num_labels == 1:
+           self.config.problem_type = "regression"
+         elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+           self.config.problem_type = "single_label_classification"
+         else:
+           self.config.problem_type = "multi_label_classification"
+
+       if self.config.problem_type == "regression":
+         loss_fct = nn.MSELoss()
+         if self.num_labels == 1:
+           loss = loss_fct(logits.squeeze(), labels.squeeze())
+         else:
+           loss = loss_fct(logits, labels)
+       elif self.config.problem_type == "single_label_classification":
+         loss_fct = nn.CrossEntropyLoss()
+         loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+       elif self.config.problem_type == "multi_label_classification":
+         loss_fct = nn.BCEWithLogitsLoss()
+         loss = loss_fct(logits, labels)
+
+     return SequenceClassifierOutput(
+         loss=loss,
+         logits=logits,
+         hidden_states=None,
+         attentions=None,
+     )
+
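A minimal end-to-end sketch (not part of the commit) of how the classes above fit together, using the hyperparameters from config.json. It assumes structformer_in_parser.py is importable locally; the random input ids are purely illustrative.

import torch
from structformer_in_parser import (
    StructFormer_In_ParserConfig, StructFormer_In_ParserModel)

# Mirror the committed config.json.
config = StructFormer_In_ParserConfig(
    hidden_size=512, nlayers=8, ntokens=16000, nhead=8,
    front_layers=4, rear_layers=4, n_parser_layers=3, conv_size=9,
    relative_bias=False, pos_emb=True, pad=1, weight_act="softmax")
model = StructFormer_In_ParserModel(config)

# Toy batch: 2 sequences of 10 token ids (avoiding the pad id 1).
input_ids = torch.randint(low=2, high=config.ntokens, size=(2, 10))

out = model(input_ids, labels=input_ids)   # MaskedLMOutput
print(out.logits.shape)                    # torch.Size([2, 10, 16000])

# parse() returns the induced structure: distance, height, head, block, etc.
parsed = model.parse(input_ids)
print(parsed["distance"].shape, parsed["height"].shape)  # torch.Size([2, 10]) each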