omarmomen committed on
Commit bdd1ce1
1 parent: 9ef743d
Files changed (3)
  1. config.json +29 -0
  2. pytorch_model.bin +3 -0
  3. structformer_as_hf.py +764 -0
config.json ADDED
@@ -0,0 +1,29 @@
+ {
+   "architectures": [
+     "StructformerModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "structformer_as_hf.StructformerConfig",
+     "AutoModelForMaskedLM": "structformer_as_hf.StructformerModel"
+   },
+   "conv_size": 9,
+   "dropatt": 0.1,
+   "dropout": 0.1,
+   "hidden_size": 768,
+   "model_type": "structformer",
+   "n_context_layers": 0,
+   "n_parser_layers": 4,
+   "nhead": 12,
+   "nlayers": 12,
+   "ntokens": 32000,
+   "pad": 0,
+   "pos_emb": true,
+   "relations": [
+     "head",
+     "child"
+   ],
+   "relative_bias": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.18.0",
+   "weight_act": "softmax"
+ }
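
The "auto_map" entries above register the custom classes from structformer_as_hf.py with the transformers Auto* loaders. A minimal loading sketch (not part of the commit; the repo id below is a placeholder for this repository's actual path, and trust_remote_code=True is required so the hub-side module is imported):

from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "omarmomen/structformer"  # placeholder: substitute this repository's real id

# trust_remote_code lets AutoConfig/AutoModelForMaskedLM resolve the classes
# listed under "auto_map" in structformer_as_hf.py.
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)           # -> StructformerConfig
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)  # -> StructformerModel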
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:b053d2370393fe370c6ca640bab063dd23fd3ad1d8e65cc1b028dbdce4c2f509
+ size 532296211
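
The file above is a Git LFS pointer; the actual checkpoint (roughly 532 MB) is fetched on git lfs pull. A small sketch (assumed local path) for checking a downloaded pytorch_model.bin against the pointer's SHA-256 and size:

import hashlib, os

path = "pytorch_model.bin"  # assumes the LFS object has already been pulled locally
expected_oid = "b053d2370393fe370c6ca640bab063dd23fd3ad1d8e65cc1b028dbdce4c2f509"
expected_size = 532296211   # bytes, as recorded in the pointer

assert os.path.getsize(path) == expected_size
sha = hashlib.sha256()
with open(path, "rb") as f:
  for chunk in iter(lambda: f.read(1 << 20), b""):
    sha.update(chunk)
assert sha.hexdigest() == expected_oid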
structformer_as_hf.py ADDED
@@ -0,0 +1,764 @@
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+ from torch.nn import init
+ from transformers import PreTrainedModel
+ from transformers import PretrainedConfig
+ from transformers.modeling_outputs import MaskedLMOutput
+ from typing import List
+ 
+ ##########################################
+ # HuggingFace Config
+ ##########################################
+ class StructformerConfig(PretrainedConfig):
+   model_type = "structformer"
+ 
+   def __init__(
+       self,
+       hidden_size=768,
+       n_context_layers=2,
+       nlayers=6,
+       ntokens=32000,
+       nhead=8,
+       dropout=0.1,
+       dropatt=0.1,
+       relative_bias=False,
+       pos_emb=False,
+       pad=0,
+       n_parser_layers=4,
+       conv_size=9,
+       relations=('head', 'child'),
+       weight_act='softmax',
+       **kwargs,
+   ):
+     self.hidden_size = hidden_size
+     self.n_context_layers = n_context_layers
+     self.nlayers = nlayers
+     self.ntokens = ntokens
+     self.nhead = nhead
+     self.dropout = dropout
+     self.dropatt = dropatt
+     self.relative_bias = relative_bias
+     self.pos_emb = pos_emb
+     self.pad = pad
+     self.n_parser_layers = n_parser_layers
+     self.conv_size = conv_size
+     self.relations = relations
+     self.weight_act = weight_act
+     super().__init__(**kwargs)
+ 
+ ##########################################
+ # Custom Layers
+ ##########################################
+ def _get_activation_fn(activation):
+   """Get specified activation function."""
+   if activation == "relu":
+     return nn.ReLU()
+   elif activation == "gelu":
+     return nn.GELU()
+   elif activation == "leakyrelu":
+     return nn.LeakyReLU()
+ 
+   raise RuntimeError(
+       "activation should be relu/gelu/leakyrelu, not {}".format(activation))
+ 
+ class Conv1d(nn.Module):
+   """1D convolution layer."""
+ 
+   def __init__(self, hidden_size, kernel_size, dilation=1):
+     """Initialization.
+ 
+     Args:
+       hidden_size: dimension of input embeddings
+       kernel_size: convolution kernel size
+       dilation: the spacing between the kernel points
+     """
+     super(Conv1d, self).__init__()
+ 
+     if kernel_size % 2 == 0:
+       padding = (kernel_size // 2) * dilation
+       self.shift = True
+     else:
+       padding = ((kernel_size - 1) // 2) * dilation
+       self.shift = False
+     self.conv = nn.Conv1d(
+         hidden_size,
+         hidden_size,
+         kernel_size,
+         padding=padding,
+         dilation=dilation)
+ 
+   def forward(self, x):
+     """Compute convolution.
+ 
+     Args:
+       x: input embeddings
+     Returns:
+       conv_output: convolution results
+     """
+ 
+     if self.shift:
+       return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
+     else:
+       return self.conv(x.transpose(1, 2)).transpose(1, 2)
+ 
+ class MultiheadAttention(nn.Module):
+   """Multi-head self-attention layer."""
+ 
+   def __init__(self,
+                embed_dim,
+                num_heads,
+                dropout=0.,
+                bias=True,
+                v_proj=True,
+                out_proj=True,
+                relative_bias=True):
+     """Initialization.
+ 
+     Args:
+       embed_dim: dimension of input embeddings
+       num_heads: number of self-attention heads
+       dropout: dropout rate
+       bias: bool, indicates whether to include bias in the linear transformations
+       v_proj: bool, indicates whether to project inputs to new values
+       out_proj: bool, indicates whether to project outputs to new values
+       relative_bias: bool, indicates whether to use a relative position based
+         attention bias
+     """
+ 
+     super(MultiheadAttention, self).__init__()
+     self.embed_dim = embed_dim
+ 
+     self.num_heads = num_heads
+     self.drop = nn.Dropout(dropout)
+     self.head_dim = embed_dim // num_heads
+     assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
+                                                          "divisible by "
+                                                          "num_heads")
+ 
+     self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     if v_proj:
+       self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     else:
+       self.v_proj = nn.Identity()
+ 
+     if out_proj:
+       self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
+     else:
+       self.out_proj = nn.Identity()
+ 
+     if relative_bias:
+       self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
+     else:
+       self.relative_bias = None
+ 
+     self._reset_parameters()
+ 
+   def _reset_parameters(self):
+     """Initialize attention parameters."""
+ 
+     init.xavier_uniform_(self.q_proj.weight)
+     init.constant_(self.q_proj.bias, 0.)
+ 
+     init.xavier_uniform_(self.k_proj.weight)
+     init.constant_(self.k_proj.bias, 0.)
+ 
+     if isinstance(self.v_proj, nn.Linear):
+       init.xavier_uniform_(self.v_proj.weight)
+       init.constant_(self.v_proj.bias, 0.)
+ 
+     if isinstance(self.out_proj, nn.Linear):
+       init.xavier_uniform_(self.out_proj.weight)
+       init.constant_(self.out_proj.bias, 0.)
+ 
+   def forward(self, query, key_padding_mask=None, attn_mask=None):
+     """Compute multi-head self-attention.
+ 
+     Args:
+       query: input embeddings
+       key_padding_mask: 3D mask that prevents attention to certain positions
+       attn_mask: 3D mask that rescales the attention weight at each position
+     Returns:
+       attn_output: self-attention output
+     """
+ 
+     length, bsz, embed_dim = query.size()
+     assert embed_dim == self.embed_dim
+ 
+     head_dim = embed_dim // self.num_heads
+     assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
+                                                     "divisible by num_heads")
+     scaling = float(head_dim)**-0.5
+ 
+     q = self.q_proj(query)
+     k = self.k_proj(query)
+     v = self.v_proj(query)
+ 
+     q = q * scaling
+ 
+     if attn_mask is not None:
+       assert list(attn_mask.size()) == [bsz * self.num_heads,
+                                         query.size(0), query.size(0)]
+ 
+     q = q.contiguous().view(length, bsz * self.num_heads,
+                             head_dim).transpose(0, 1)
+     k = k.contiguous().view(length, bsz * self.num_heads,
+                             head_dim).transpose(0, 1)
+     v = v.contiguous().view(length, bsz * self.num_heads,
+                             head_dim).transpose(0, 1)
+ 
+     attn_output_weights = torch.bmm(q, k.transpose(1, 2))
+     assert list(
+         attn_output_weights.size()) == [bsz * self.num_heads, length, length]
+ 
+     if self.relative_bias is not None:
+       pos = torch.arange(length, device=query.device)
+       relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
+       relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
+                                                      -1)
+ 
+       relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
+       relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
+       relative_bias = torch.gather(relative_bias, 2, relative_pos)
+       attn_output_weights = attn_output_weights + relative_bias
+ 
+     if key_padding_mask is not None:
+       attn_output_weights = attn_output_weights + key_padding_mask
+ 
+     if attn_mask is None:
+       attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
+     else:
+       attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
+ 
+     attn_output_weights = self.drop(attn_output_weights)
+ 
+     attn_output = torch.bmm(attn_output_weights, v)
+ 
+     assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
+     attn_output = attn_output.transpose(0, 1).contiguous().view(
+         length, bsz, embed_dim)
+     attn_output = self.out_proj(attn_output)
+ 
+     return attn_output
+ 
+ class TransformerLayer(nn.Module):
+   """TransformerEncoderLayer is made up of self-attn and feedforward network."""
+ 
+   def __init__(self,
+                d_model,
+                nhead,
+                dim_feedforward=2048,
+                dropout=0.1,
+                dropatt=0.1,
+                activation="leakyrelu",
+                relative_bias=True):
+     """Initialization.
+ 
+     Args:
+       d_model: dimension of inputs
+       nhead: number of self-attention heads
+       dim_feedforward: dimension of hidden layer in feedforward layer
+       dropout: dropout rate
+       dropatt: drop attention rate
+       activation: activation function
+       relative_bias: bool, indicates whether to use a relative position based
+         attention bias
+     """
+ 
+     super(TransformerLayer, self).__init__()
+ 
+     self.self_attn = MultiheadAttention(
+         d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
+ 
+     # Implementation of the feedforward model
+     self.feedforward = nn.Sequential(
+         nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
+         _get_activation_fn(activation), nn.Dropout(dropout),
+         nn.Linear(dim_feedforward, d_model))
+ 
+     self.norm = nn.LayerNorm(d_model)
+     self.dropout1 = nn.Dropout(dropout)
+     self.dropout2 = nn.Dropout(dropout)
+ 
+     self.nhead = nhead
+ 
+   def forward(self, src, attn_mask=None, key_padding_mask=None):
+     """Pass the input through the encoder layer.
+ 
+     Args:
+       src: the sequence to the encoder layer (required).
+       attn_mask: the mask for the src sequence (optional).
+       key_padding_mask: the mask for the src keys per batch (optional).
+     Returns:
+       src3: the output of the transformer layer; shares the same shape as src.
+     """
+     src2 = self.self_attn(
+         self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
+     src2 = src + self.dropout1(src2)
+     src3 = self.feedforward(src2)
+     src3 = src2 + self.dropout2(src3)
+ 
+     return src3
+ 
+ ##########################################
+ # Custom Models
+ ##########################################
+ def cumprod(x, reverse=False, exclusive=False):
+   """cumulative product."""
+   if reverse:
+     x = x.flip([-1])
+ 
+   if exclusive:
+     x = F.pad(x[:, :, :-1], (1, 0), value=1)
+ 
+   cx = x.cumprod(-1)
+ 
+   if reverse:
+     cx = cx.flip([-1])
+   return cx
+ 
+ def cumsum(x, reverse=False, exclusive=False):
+   """cumulative sum."""
+   bsz, _, length = x.size()
+   device = x.device
+   if reverse:
+     if exclusive:
+       w = torch.ones([bsz, length, length], device=device).tril(-1)
+     else:
+       w = torch.ones([bsz, length, length], device=device).tril(0)
+     cx = torch.bmm(x, w)
+   else:
+     if exclusive:
+       w = torch.ones([bsz, length, length], device=device).triu(1)
+     else:
+       w = torch.ones([bsz, length, length], device=device).triu(0)
+     cx = torch.bmm(x, w)
+   return cx
+ 
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
+   """cumulative min."""
+   if reverse:
+     if exclusive:
+       x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
+     x = x.flip([-1]).cummin(-1)[0].flip([-1])
+   else:
+     if exclusive:
+       x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
+     x = x.cummin(-1)[0]
+   return x
+ 
+ class Transformer(nn.Module):
+   """Transformer model."""
+ 
+   def __init__(self,
+                hidden_size,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=True,
+                pos_emb=False,
+                pad=0):
+     """Initialization.
+ 
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicates whether to use a relative position based
+         attention bias
+       pos_emb: bool, indicates whether to use a learnable positional embedding
+       pad: pad token index
+     """
+ 
+     super(Transformer, self).__init__()
+ 
+     self.drop = nn.Dropout(dropout)
+ 
+     self.emb = nn.Embedding(ntokens, hidden_size)
+     if pos_emb:
+       self.pos_emb = nn.Embedding(500, hidden_size)
+ 
+     self.layers = nn.ModuleList([
+         TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                          dropatt=dropatt, relative_bias=relative_bias)
+         for _ in range(nlayers)])
+ 
+     self.norm = nn.LayerNorm(hidden_size)
+ 
+     self.output_layer = nn.Linear(hidden_size, ntokens)
+     self.output_layer.weight = self.emb.weight
+ 
+     self.init_weights()
+ 
+     self.nlayers = nlayers
+     self.nhead = nhead
+     self.ntokens = ntokens
+     self.hidden_size = hidden_size
+     self.pad = pad
+ 
+   def init_weights(self):
+     """Initialize token embedding and output bias."""
+     initrange = 0.1
+     self.emb.weight.data.uniform_(-initrange, initrange)
+     if hasattr(self, 'pos_emb'):
+       self.pos_emb.weight.data.uniform_(-initrange, initrange)
+     self.output_layer.bias.data.fill_(0)
+ 
+   def visibility(self, x, device):
+     """Mask pad tokens."""
+     visibility = (x != self.pad).float()
+     visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
+     visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
+     return visibility.log()
+ 
+   def encode(self, x, pos):
+     """Standard transformer encode process."""
+     h = self.emb(x)
+     if hasattr(self, 'pos_emb'):
+       h = h + self.pos_emb(pos)
+     h_list = []
+     visibility = self.visibility(x, x.device)
+ 
+     for i in range(self.nlayers):
+       h_list.append(h)
+       h = self.layers[i](
+           h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
+ 
+     output = h
+     h_array = torch.stack(h_list, dim=2)
+ 
+     return output, h_array
+ 
+   def forward(self, x, pos):
+     """Pass the input through the encoder layer.
+ 
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       output: probability distributions for missing tokens.
+       state_dict: parsing results and raw output
+     """
+ 
+     batch_size, length = x.size()
+ 
+     raw_output, _ = self.encode(x, pos)
+     raw_output = self.norm(raw_output)
+     raw_output = self.drop(raw_output)
+ 
+     output = self.output_layer(raw_output)
+     return output.view(batch_size * length, -1), {'raw_output': raw_output,}
+ 
+ class StructFormer(Transformer):
+   """StructFormer model."""
+ 
+   def __init__(self,
+                hidden_size,
+                n_context_layers,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=False,
+                pos_emb=False,
+                pad=0,
+                n_parser_layers=4,
+                conv_size=9,
+                relations=('head', 'child'),
+                weight_act='softmax'):
+     """Initialization.
+ 
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicates whether to use a relative position based
+         attention bias
+       pos_emb: bool, indicates whether to use a learnable positional embedding
+       pad: pad token index
+       n_parser_layers: number of parsing layers
+       conv_size: convolution kernel size for parser
+       relations: relations that are used to compute self attention
+       weight_act: relations distribution activation function
+     """
+ 
+     super(StructFormer, self).__init__(
+         hidden_size,
+         nlayers,
+         ntokens,
+         nhead=nhead,
+         dropout=dropout,
+         dropatt=dropatt,
+         relative_bias=relative_bias,
+         pos_emb=pos_emb,
+         pad=pad)
+ 
+     if n_context_layers > 0:
+       self.context_layers = nn.ModuleList([
+           TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                            dropatt=dropatt, relative_bias=relative_bias)
+           for _ in range(n_context_layers)])
+ 
+     self.parser_layers = nn.ModuleList([
+         nn.Sequential(Conv1d(hidden_size, conv_size),
+                       nn.LayerNorm(hidden_size, elementwise_affine=False),
+                       nn.Tanh()) for i in range(n_parser_layers)])
+ 
+     self.distance_ff = nn.Sequential(
+         Conv1d(hidden_size, 2),
+         nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+         nn.Linear(hidden_size, 1))
+ 
+     self.height_ff = nn.Sequential(
+         nn.Linear(hidden_size, hidden_size),
+         nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+         nn.Linear(hidden_size, 1))
+ 
+     n_rel = len(relations)
+     self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
+     self._rel_weight.data.normal_(0, 0.1)
+ 
+     self._scaler = nn.Parameter(torch.zeros(2))
+ 
+     self.n_parse_layers = n_parser_layers
+     self.n_context_layers = n_context_layers
+     self.weight_act = weight_act
+     self.relations = relations
+ 
+   @property
+   def scaler(self):
+     return self._scaler.exp()
+ 
+   @property
+   def rel_weight(self):
+     if self.weight_act == 'sigmoid':
+       return torch.sigmoid(self._rel_weight)
+     elif self.weight_act == 'softmax':
+       return torch.softmax(self._rel_weight, dim=-1)
+ 
+   def parse(self, x, pos, embeds=None):
+     """Parse input sentence.
+ 
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       distance: syntactic distance
+       height: syntactic height
+     """
+ 
+     mask = (x != self.pad)
+     mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
+ 
+ 
+     if embeds is not None:
+       h = embeds
+     else:
+       h = self.emb(x)
+ 
+     for i in range(self.n_parse_layers):
+       h = h.masked_fill(~mask[:, :, None], 0)
+       h = self.parser_layers[i](h)
+ 
+     height = self.height_ff(h).squeeze(-1)
+     height.masked_fill_(~mask, -1e9)
+ 
+     distance = self.distance_ff(h).squeeze(-1)
+     distance.masked_fill_(~mask_shifted, 1e9)
+ 
+     # Calibrating the distance and height to the same level
+     length = distance.size(1)
+     height_max = height[:, None, :].expand(-1, length, -1)
+     height_max = torch.cummax(
+         height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
+         dim=-1)[0].triu(0)
+ 
+     margin_left = torch.relu(
+         F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
+     margin_right = torch.relu(distance[:, None, :] - height_max)
+     margin = torch.where(margin_left > margin_right, margin_right,
+                          margin_left).triu(0)
+ 
+     margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
+     margin.masked_fill_(~margin_mask, 0)
+     margin = margin.max()
+ 
+     distance = distance - margin
+ 
+     return distance, height
+ 
+   def compute_block(self, distance, height):
+     """Compute constituents from distance and height."""
+ 
+     beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
+ 
+     gamma = torch.sigmoid(-beta_logits)
+     ones = torch.ones_like(gamma)
+ 
+     block_mask_left = cummin(
+         gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
+     block_mask_left = block_mask_left - F.pad(
+         block_mask_left[:, :, :-1], (1, 0), value=0)
+     block_mask_left.tril_(0)
+ 
+     block_mask_right = cummin(
+         gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
+     block_mask_right = block_mask_right - F.pad(
+         block_mask_right[:, :, 1:], (0, 1), value=0)
+     block_mask_right.triu_(0)
+ 
+     block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
+     block = cumsum(block_mask_left).tril(0) + cumsum(
+         block_mask_right, reverse=True).triu(1)
+ 
+     return block_p, block
+ 
+   def compute_head(self, height):
+     """Estimate head for each constituent."""
+ 
+     _, length = height.size()
+     head_logits = height * self.scaler[1]
+     index = torch.arange(length, device=height.device)
+ 
+     mask = (index[:, None, None] <= index[None, None, :]) * (
+         index[None, None, :] <= index[None, :, None])
+     head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
+     head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
+ 
+     head_p = torch.softmax(head_logits, dim=-1)
+ 
+     return head_p
+ 
+   def generate_mask(self, x, distance, height):
+     """Compute head and cibling distribution for each token."""
+ 
+     bsz, length = x.size()
+ 
+     eye = torch.eye(length, device=x.device, dtype=torch.bool)
+     eye = eye[None, :, :].expand((bsz, -1, -1))
+ 
+     block_p, block = self.compute_block(distance, height)
+     head_p = self.compute_head(height)
+     head = torch.einsum('blij,bijh->blh', block_p, head_p)
+     head = head.masked_fill(eye, 0)
+     child = head.transpose(1, 2)
+     cibling = torch.bmm(head, child).masked_fill(eye, 0)
+ 
+     rel_list = []
+     if 'head' in self.relations:
+       rel_list.append(head)
+     if 'child' in self.relations:
+       rel_list.append(child)
+     if 'cibling' in self.relations:
+       rel_list.append(cibling)
+ 
+     rel = torch.stack(rel_list, dim=1)
+ 
+     rel_weight = self.rel_weight
+ 
+     dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
+     att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
+ 
+     return att_mask, cibling, head, block
+ 
+   def encode(self, x, pos, att_mask=None, context_layers=False):
+     """Structformer encoding process."""
+ 
+     if context_layers:
+       # Standard transformer encode process.
+       h = self.emb(x)
+       if hasattr(self, 'pos_emb'):
+         h = h + self.pos_emb(pos)
+       h_list = []
+       visibility = self.visibility(x, x.device)
+       for i in range(self.n_context_layers):
+         h_list.append(h)
+         h = self.context_layers[i](
+             h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
+ 
+       output = h
+       h_array = torch.stack(h_list, dim=2)
+       return output
+ 
+     else:
+       visibility = self.visibility(x, x.device)
+       h = self.emb(x)
+       if hasattr(self, 'pos_emb'):
+         assert pos.max() < 500
+         h = h + self.pos_emb(pos)
+       for i in range(self.nlayers):
+         h = self.layers[i](
+             h.transpose(0, 1), attn_mask=att_mask[i],
+             key_padding_mask=visibility).transpose(0, 1)
+       return h
+ 
+   def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
+ 
+     x = input_ids
+     batch_size, length = x.size()
+ 
+     if position_ids is None:
+       position_ids = torch.arange(length, device=x.device).expand(batch_size, length)
+     pos = position_ids
+     context_layers_output = None
+     if self.n_context_layers > 0:
+       context_layers_output = self.encode(x, pos, context_layers=True)
+ 
+     distance, height = self.parse(x, pos, embeds=context_layers_output)
+     att_mask, cibling, head, block = self.generate_mask(x, distance, height)
+ 
+     raw_output = self.encode(x, pos, att_mask)
+     raw_output = self.norm(raw_output)
+     raw_output = self.drop(raw_output)
+ 
+     output = self.output_layer(raw_output)
+ 
+     loss = None
+     if labels is not None:
+       loss_fct = nn.CrossEntropyLoss()
+       loss = loss_fct(output.view(batch_size * length, -1), labels.reshape(-1))
+ 
+     return MaskedLMOutput(
+         loss=loss,  # scalar
+         logits=output,  # shape: (batch_size, length, ntokens)
+         hidden_states=None,
+         attentions=None,
+     )
+ 
+ 
+ ##########################################
+ # HuggingFace Model
+ ##########################################
+ class StructformerModel(PreTrainedModel):
+   config_class = StructformerConfig
+ 
+   def __init__(self, config):
+     super().__init__(config)
+     self.model = StructFormer(
+         hidden_size=config.hidden_size,
+         n_context_layers=config.n_context_layers,
+         nlayers=config.nlayers,
+         ntokens=config.ntokens,
+         nhead=config.nhead,
+         dropout=config.dropout,
+         dropatt=config.dropatt,
+         relative_bias=config.relative_bias,
+         pos_emb=config.pos_emb,
+         pad=config.pad,
+         n_parser_layers=config.n_parser_layers,
+         conv_size=config.conv_size,
+         relations=config.relations,
+         weight_act=config.weight_act
+     )
+ 
+   def forward(self, input_ids, labels=None, **kwargs):
+     return self.model(input_ids, labels=labels, **kwargs)
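
A minimal smoke test for the module above (not part of the commit): it instantiates StructformerConfig with the values from config.json and runs a randomly initialized StructformerModel forward pass to confirm the output shapes. It assumes structformer_as_hf.py is on the import path and does not load the pretrained weights.

import torch
from structformer_as_hf import StructformerConfig, StructformerModel

# Values mirror config.json in this commit.
config = StructformerConfig(
    hidden_size=768, n_context_layers=0, nlayers=12, ntokens=32000, nhead=12,
    dropout=0.1, dropatt=0.1, relative_bias=False, pos_emb=True, pad=0,
    n_parser_layers=4, conv_size=9, relations=('head', 'child'),
    weight_act='softmax')
model = StructformerModel(config)

input_ids = torch.randint(1, config.ntokens, (2, 16))  # batch of 2, length 16, no pad tokens
out = model(input_ids, labels=input_ids)
print(out.logits.shape)  # torch.Size([2, 16, 32000])
print(out.loss)          # scalar cross-entropy over all positions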