omarmomen committed on
Commit
4ef095f
1 Parent(s): 257e44a
Files changed (3)
  1. config.json +30 -0
  2. pytorch_model.bin +3 -0
  3. structformer.py +616 -0
config.json ADDED
@@ -0,0 +1,30 @@
+ {
+   "architectures": [
+     "StructFormerModel"
+   ],
+   "auto_map": {
+     "AutoConfig": "structformer.StructFormerConfig",
+     "AutoModelForMaskedLM": "structformer.StructFormerModel"
+   },
+   "conv_size": 9,
+   "dropatt": 0.1,
+   "dropout": 0.1,
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 512,
+   "initializer_range": 0.02,
+   "model_type": "structformer",
+   "n_parser_layers": 3,
+   "nhead": 8,
+   "nlayers": 8,
+   "ntokens": 16000,
+   "pad": 1,
+   "pos_emb": true,
+   "relations": [
+     "head",
+     "child"
+   ],
+   "relative_bias": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.18.0",
+   "weight_act": "softmax"
+ }
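Per the `auto_map` above, the checkpoint is meant to be loaded with `trust_remote_code=True`, which pulls `StructFormerConfig` and `StructFormerModel` out of `structformer.py`. A minimal loading sketch (the repo id is a placeholder for wherever this commit lives):

```python
from transformers import AutoConfig, AutoModelForMaskedLM

repo_id = "omarmomen/structformer"  # placeholder repo id
config = AutoConfig.from_pretrained(repo_id, trust_remote_code=True)
model = AutoModelForMaskedLM.from_pretrained(repo_id, trust_remote_code=True)
```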
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:f3bbbcda010c3d3c6f67547cda9045d4fe77732ffbb47d1cbc7e596a6d5e1e5d
+ size 166255895
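This is a Git LFS pointer: the repository versions only the sha256 oid and byte size, not the ~166 MB of weights. A quick sketch for checking a downloaded copy against the pointer:

```python
import hashlib

sha256 = hashlib.sha256()
with open("pytorch_model.bin", "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):  # hash in 1 MiB chunks
        sha256.update(chunk)
assert sha256.hexdigest() == (
    "f3bbbcda010c3d3c6f67547cda9045d4fe77732ffbb47d1cbc7e596a6d5e1e5d")
```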
structformer.py ADDED
@@ -0,0 +1,616 @@
+ # coding=utf-8
+ # Copyright 2023 The Google Research Authors.
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ #     http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
+
+ """StructFormer and transformer model."""
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+ import layers  # companion module providing TransformerLayer and Conv1d (not among this commit's files)
+
+ from transformers import PretrainedConfig, PreTrainedModel
+ from transformers.modeling_outputs import MaskedLMOutput, SequenceClassifierOutput
+
+
+ def cumprod(x, reverse=False, exclusive=False):
+   """cumulative product."""
+   if reverse:
+     x = x.flip([-1])
+
+   if exclusive:
+     x = F.pad(x[:, :, :-1], (1, 0), value=1)
+
+   cx = x.cumprod(-1)
+
+   if reverse:
+     cx = cx.flip([-1])
+   return cx
+
+ def cumsum(x, reverse=False, exclusive=False):
+   """cumulative sum."""
+   bsz, _, length = x.size()
+   device = x.device
+   if reverse:
+     if exclusive:
+       w = torch.ones([bsz, length, length], device=device).tril(-1)
+     else:
+       w = torch.ones([bsz, length, length], device=device).tril(0)
+     cx = torch.bmm(x, w)
+   else:
+     if exclusive:
+       w = torch.ones([bsz, length, length], device=device).triu(1)
+     else:
+       w = torch.ones([bsz, length, length], device=device).triu(0)
+     cx = torch.bmm(x, w)
+   return cx
+
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
+   """cumulative min."""
+   if reverse:
+     if exclusive:
+       x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
+     x = x.flip([-1]).cummin(-1)[0].flip([-1])
+   else:
+     if exclusive:
+       x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
+     x = x.cummin(-1)[0]
+   return x
+
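These helpers compute batched cumulative operations over the last axis of a `(batch, rows, length)` tensor; `cumsum` does it with a matrix multiply so the exclusive/reverse variants fall out of the mask shape. A small sanity check (assuming `structformer.py` and its `layers` dependency are importable):

```python
import torch
from structformer import cumsum, cummin

x = torch.tensor([[[1., 2., 3.]]])  # (batch=1, rows=1, length=3)
assert torch.equal(cumsum(x), torch.tensor([[[1., 3., 6.]]]))                   # prefix sums
assert torch.equal(cumsum(x, exclusive=True), torch.tensor([[[0., 1., 3.]]]))  # shifted right
assert torch.equal(cumsum(x, reverse=True), torch.tensor([[[6., 5., 3.]]]))    # suffix sums
assert torch.equal(cummin(x, reverse=True), x)  # running min from the right of an increasing row
```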
+ class Transformer(nn.Module):
+   """Transformer model."""
+
+   def __init__(self,
+                hidden_size,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=True,
+                pos_emb=False,
+                pad=0):
+     """Initialization.
+
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+       pos_emb: bool, indicate whether use a learnable positional embedding
+       pad: pad token index
+     """
+
+     super(Transformer, self).__init__()
+
+     self.drop = nn.Dropout(dropout)
+
+     self.emb = nn.Embedding(ntokens, hidden_size)
+     if pos_emb:
+       self.pos_emb = nn.Embedding(500, hidden_size)
+
+     self.layers = nn.ModuleList([
+         layers.TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
+                                 dropatt=dropatt, relative_bias=relative_bias)
+         for _ in range(nlayers)])
+
+     self.norm = nn.LayerNorm(hidden_size)
+
+     self.output_layer = nn.Linear(hidden_size, ntokens)
+     self.output_layer.weight = self.emb.weight
+
+     self.init_weights()
+
+     self.nlayers = nlayers
+     self.nhead = nhead
+     self.ntokens = ntokens
+     self.hidden_size = hidden_size
+     self.pad = pad
+
+   def init_weights(self):
+     """Initialize token embedding and output bias."""
+     initrange = 0.1
+     self.emb.weight.data.uniform_(-initrange, initrange)
+     if hasattr(self, 'pos_emb'):
+       self.pos_emb.weight.data.uniform_(-initrange, initrange)
+     self.output_layer.bias.data.fill_(0)
+
+   def visibility(self, x, device):
+     """Mask pad tokens."""
+     visibility = (x != self.pad).float()
+     visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
+     visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
+     return visibility.log()
+
+   def encode(self, x, pos):
+     """Standard transformer encode process."""
+     h = self.emb(x)
+     if hasattr(self, 'pos_emb'):
+       h = h + self.pos_emb(pos)
+     h_list = []
+     visibility = self.visibility(x, x.device)
+
+     for i in range(self.nlayers):
+       h_list.append(h)
+       h = self.layers[i](
+           h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
+
+     output = h
+     h_array = torch.stack(h_list, dim=2)
+
+     return output, h_array
+
+   def forward(self, x, pos):
+     """Pass the input through the encoder layer.
+
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       output: probability distributions for missing tokens.
+       state_dict: parsing results and raw output
+     """
+
+     batch_size, length = x.size()
+
+     raw_output, _ = self.encode(x, pos)
+     raw_output = self.norm(raw_output)
+     raw_output = self.drop(raw_output)
+
+     output = self.output_layer(raw_output)
+     return output.view(batch_size * length, -1), {'raw_output': raw_output,}
+
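A smoke test for the base `Transformer` (a sketch; it needs the companion `layers` module, which is not among this commit's three files):

```python
import torch
from structformer import Transformer

model = Transformer(hidden_size=64, nlayers=2, ntokens=100, nhead=4, pad=0)
x = torch.randint(2, 100, (4, 10))    # token ids, keeping 0 free for padding
pos = torch.arange(10).expand(4, 10)  # only consulted when pos_emb=True
logits, state = model(x, pos)
print(logits.shape)                   # torch.Size([40, 100]) = (batch * length, ntokens)
print(state['raw_output'].shape)      # torch.Size([4, 10, 64])
```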
+ class StructFormer(Transformer):
+   """StructFormer model."""
+
+   def __init__(self,
+                hidden_size,
+                nlayers,
+                ntokens,
+                nhead=8,
+                dropout=0.1,
+                dropatt=0.1,
+                relative_bias=False,
+                pos_emb=False,
+                pad=0,
+                n_parser_layers=4,
+                conv_size=9,
+                relations=('head', 'child'),
+                weight_act='softmax'):
+     """Initialization.
+
+     Args:
+       hidden_size: dimension of inputs and hidden states
+       nlayers: number of layers
+       ntokens: number of output categories
+       nhead: number of self-attention heads
+       dropout: dropout rate
+       dropatt: drop attention rate
+       relative_bias: bool, indicate whether use a relative position based
+         attention bias
+       pos_emb: bool, indicate whether use a learnable positional embedding
+       pad: pad token index
+       n_parser_layers: number of parsing layers
+       conv_size: convolution kernel size for parser
+       relations: relations that are used to compute self attention
+       weight_act: relations distribution activation function
+     """
+
+     super(StructFormer, self).__init__(
+         hidden_size,
+         nlayers,
+         ntokens,
+         nhead=nhead,
+         dropout=dropout,
+         dropatt=dropatt,
+         relative_bias=relative_bias,
+         pos_emb=pos_emb,
+         pad=pad)
+
+     self.parser_layers = nn.ModuleList([
+         nn.Sequential(layers.Conv1d(hidden_size, conv_size),
+                       nn.LayerNorm(hidden_size, elementwise_affine=False),
+                       nn.Tanh()) for i in range(n_parser_layers)])
+
+     self.distance_ff = nn.Sequential(
+         layers.Conv1d(hidden_size, 2),
+         nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+         nn.Linear(hidden_size, 1))
+
+     self.height_ff = nn.Sequential(
+         nn.Linear(hidden_size, hidden_size),
+         nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
+         nn.Linear(hidden_size, 1))
+
+     n_rel = len(relations)
+     self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
+     self._rel_weight.data.normal_(0, 0.1)
+
+     self._scaler = nn.Parameter(torch.zeros(2))
+
+     self.n_parse_layers = n_parser_layers
+     self.weight_act = weight_act
+     self.relations = relations
+
+   @property
+   def scaler(self):
+     return self._scaler.exp()
+
+   @property
+   def rel_weight(self):
+     if self.weight_act == 'sigmoid':
+       return torch.sigmoid(self._rel_weight)
+     elif self.weight_act == 'softmax':
+       return torch.softmax(self._rel_weight, dim=-1)
+
+   def parse(self, x, pos):
+     """Parse input sentence.
+
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       distance: syntactic distance
+       height: syntactic height
+     """
+
+     mask = (x != self.pad)
+     mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
+
+     h = self.emb(x)
+     for i in range(self.n_parse_layers):
+       h = h.masked_fill(~mask[:, :, None], 0)
+       h = self.parser_layers[i](h)
+
+     height = self.height_ff(h).squeeze(-1)
+     height.masked_fill_(~mask, -1e9)
+
+     distance = self.distance_ff(h).squeeze(-1)
+     distance.masked_fill_(~mask_shifted, 1e9)
+
+     # Calibrating the distance and height to the same level
+     length = distance.size(1)
+     height_max = height[:, None, :].expand(-1, length, -1)
+     height_max = torch.cummax(
+         height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
+         dim=-1)[0].triu(0)
+
+     margin_left = torch.relu(
+         F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
+     margin_right = torch.relu(distance[:, None, :] - height_max)
+     margin = torch.where(margin_left > margin_right, margin_right,
+                          margin_left).triu(0)
+
+     margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
+     margin.masked_fill_(~margin_mask, 0)
+     margin = margin.max()
+
+     distance = distance - margin
+
+     return distance, height
+
+   def compute_block(self, distance, height):
+     """Compute constituents from distance and height."""
+
+     beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
+
+     gamma = torch.sigmoid(-beta_logits)
+     ones = torch.ones_like(gamma)
+
+     block_mask_left = cummin(
+         gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
+     block_mask_left = block_mask_left - F.pad(
+         block_mask_left[:, :, :-1], (1, 0), value=0)
+     block_mask_left.tril_(0)
+
+     block_mask_right = cummin(
+         gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
+     block_mask_right = block_mask_right - F.pad(
+         block_mask_right[:, :, 1:], (0, 1), value=0)
+     block_mask_right.triu_(0)
+
+     block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
+     block = cumsum(block_mask_left).tril(0) + cumsum(
+         block_mask_right, reverse=True).triu(1)
+
+     return block_p, block
+
+   def compute_head(self, height):
+     """Estimate head for each constituent."""
+
+     _, length = height.size()
+     head_logits = height * self.scaler[1]
+     index = torch.arange(length, device=height.device)
+
+     mask = (index[:, None, None] <= index[None, None, :]) * (
+         index[None, None, :] <= index[None, :, None])
+     head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
+     head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
+
+     head_p = torch.softmax(head_logits, dim=-1)
+
+     return head_p
+
+   def generate_mask(self, x, distance, height):
+     """Compute head and cibling distribution for each token."""
+
+     bsz, length = x.size()
+
+     eye = torch.eye(length, device=x.device, dtype=torch.bool)
+     eye = eye[None, :, :].expand((bsz, -1, -1))
+
+     block_p, block = self.compute_block(distance, height)
+     head_p = self.compute_head(height)
+     head = torch.einsum('blij,bijh->blh', block_p, head_p)
+     head = head.masked_fill(eye, 0)
+     child = head.transpose(1, 2)
+     cibling = torch.bmm(head, child).masked_fill(eye, 0)
+
+     rel_list = []
+     if 'head' in self.relations:
+       rel_list.append(head)
+     if 'child' in self.relations:
+       rel_list.append(child)
+     if 'cibling' in self.relations:
+       rel_list.append(cibling)
+
+     rel = torch.stack(rel_list, dim=1)
+
+     rel_weight = self.rel_weight
+
+     dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
+     att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
+
+     return att_mask, cibling, head, block
+
+   def encode(self, x, pos, att_mask):
+     """Structformer encoding process."""
+
+     visibility = self.visibility(x, x.device)
+     h = self.emb(x)
+     if hasattr(self, 'pos_emb'):
+       assert pos.max() < 500
+       h = h + self.pos_emb(pos)
+     for i in range(self.nlayers):
+       h = self.layers[i](
+           h.transpose(0, 1), attn_mask=att_mask[i],
+           key_padding_mask=visibility).transpose(0, 1)
+     return h
+
+   def forward(self, x, pos):
+     """Pass the input through the encoder layer.
+
+     Args:
+       x: input tokens (required).
+       pos: position for each token (optional).
+     Returns:
+       output: probability distributions for missing tokens.
+       state_dict: parsing results and raw output
+     """
+
+     batch_size, length = x.size()
+
+     distance, height = self.parse(x, pos)
+     att_mask, cibling, head, block = self.generate_mask(x, distance, height)
+
+     raw_output = self.encode(x, pos, att_mask)
+     raw_output = self.norm(raw_output)
+     raw_output = self.drop(raw_output)
+
+     output = self.output_layer(raw_output)
+
+     return output.view(batch_size * length, -1), \
+         {'raw_output': raw_output, 'distance': distance, 'height': height,
+          'cibling': cibling, 'head': head, 'block': block}
+
+
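Unlike the base class, `StructFormer.forward` also returns the induced parse (distance, height, and the derived head/cibling/block structures). A sketch under the same `layers`-module assumption:

```python
import torch
from structformer import StructFormer

sf = StructFormer(hidden_size=64, nlayers=2, ntokens=100, nhead=4,
                  pad=0, n_parser_layers=2)
x = torch.randint(2, 100, (1, 12))
pos = torch.arange(12).expand(1, 12)
logits, state = sf(x, pos)
print(state['distance'].shape)  # torch.Size([1, 12]): syntactic distance per position
print(state['height'].shape)    # torch.Size([1, 12]): syntactic height per token
print(state['head'].shape)      # torch.Size([1, 12, 12]): head distribution per token
```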
+ ##########################################
+ # Classification Head for BabyLM Evaluation Tasks
+ ##########################################
+ class ClassificationHead(nn.Module):
+   """Head for sentence-level classification tasks."""
+   def __init__(self, config):
+     super(ClassificationHead, self).__init__()
+     self.dense = nn.Linear(config.hidden_size, config.hidden_size)
+     self.dropout = nn.Dropout(config.hidden_dropout_prob)
+     self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
+
+   def forward(self, features, **kwargs):
+     x = features[:, 0, :]  # take <s> token (equiv. to [CLS])
+     x = self.dropout(x)
+     x = self.dense(x)
+     x = torch.tanh(x)
+     x = self.dropout(x)
+     x = self.out_proj(x)
+     return x
+
+ ##########################################
+ # HuggingFace Config
+ ##########################################
+ class StructFormerConfig(PretrainedConfig):
+   model_type = "structformer"
+
+   def __init__(
+       self,
+       hidden_size=512,
+       nlayers=8,
+       ntokens=10_000,
+       nhead=8,
+       dropout=0.1,
+       dropatt=0.1,
+       relative_bias=False,
+       pos_emb=False,
+       pad=0,
+       n_parser_layers=4,
+       conv_size=9,
+       relations=('head', 'child'),
+       weight_act='softmax',
+       num_labels=1,
+       hidden_dropout_prob=0.1,
+       initializer_range=0.02,
+       **kwargs,
+   ):
+     self.hidden_size = hidden_size
+     self.nlayers = nlayers
+     self.ntokens = ntokens
+     self.nhead = nhead
+     self.dropout = dropout
+     self.dropatt = dropatt
+     self.relative_bias = relative_bias
+     self.pos_emb = pos_emb
+     self.pad = pad
+     self.n_parser_layers = n_parser_layers
+     self.conv_size = conv_size
+     self.relations = relations
+     self.weight_act = weight_act
+     self.num_labels = num_labels
+     self.hidden_dropout_prob = hidden_dropout_prob
+     self.initializer_range = initializer_range
+     # Pass num_labels through to PretrainedConfig.__init__, which otherwise
+     # resets it to its own default when the key is absent from kwargs.
+     super().__init__(num_labels=num_labels, **kwargs)
+
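For reference, the shipped config.json at the top of this commit corresponds to:

```python
from structformer import StructFormerConfig

config = StructFormerConfig(
    hidden_size=512, nlayers=8, ntokens=16000, nhead=8,
    dropout=0.1, dropatt=0.1, relative_bias=False, pos_emb=True,
    pad=1, n_parser_layers=3, conv_size=9,
    relations=('head', 'child'), weight_act='softmax')
```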
+ ##########################################
+ # HuggingFace Model
+ ##########################################
+ class StructFormerModel(PreTrainedModel):
+   config_class = StructFormerConfig
+
+   def __init__(self, config):
+     super().__init__(config)
+     self.model = StructFormer(
+         hidden_size=config.hidden_size,
+         nlayers=config.nlayers,
+         ntokens=config.ntokens,
+         nhead=config.nhead,
+         dropout=config.dropout,
+         dropatt=config.dropatt,
+         relative_bias=config.relative_bias,
+         pos_emb=config.pos_emb,
+         pad=config.pad,
+         n_parser_layers=config.n_parser_layers,
+         conv_size=config.conv_size,
+         relations=config.relations,
+         weight_act=config.weight_act
+     )
+     self.config = config
+
+   def parse(self, input_ids, **kwargs):
+     x = input_ids
+     batch_size, length = x.size()
+     pos = kwargs.get(
+         'position_ids',
+         torch.arange(length, device=x.device).expand(batch_size, length))
+
+     sf_output = self.model(x, pos)
+
+     return sf_output[1]
+
+   def forward(self, input_ids, labels=None, **kwargs):
+     x = input_ids
+     batch_size, length = x.size()
+     pos = kwargs.get(
+         'position_ids',
+         torch.arange(length, device=x.device).expand(batch_size, length))
+
+     sf_output = self.model(x, pos)
+
+     loss = None
+     if labels is not None:
+       loss_fct = nn.CrossEntropyLoss()
+       loss = loss_fct(sf_output[0], labels.reshape(-1))
+
+     return MaskedLMOutput(
+         loss=loss,  # scalar
+         logits=sf_output[0].view(batch_size, length, -1),  # shape: (batch_size, length, ntokens)
+         hidden_states=None,
+         attentions=None
+     )
+
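Masked-LM usage of the wrapper (a sketch, still assuming the `layers` module; per the usual convention, `nn.CrossEntropyLoss` ignores label positions set to -100):

```python
import torch
from structformer import StructFormerConfig, StructFormerModel

config = StructFormerConfig(hidden_size=64, nlayers=2, ntokens=100, nhead=4)
model = StructFormerModel(config)
input_ids = torch.randint(2, 100, (2, 12))
labels = input_ids.clone()  # real MLM training would set -100 at unmasked positions
out = model(input_ids, labels=labels)
print(out.loss)             # scalar cross-entropy
print(out.logits.shape)     # torch.Size([2, 12, 100])
```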
+ class StructFormerModelForSequenceClassification(PreTrainedModel):
+   config_class = StructFormerConfig
+
+   def __init__(self, config):
+     super().__init__(config)
+     self.model = StructFormer(
+         hidden_size=config.hidden_size,
+         nlayers=config.nlayers,
+         ntokens=config.ntokens,
+         nhead=config.nhead,
+         dropout=config.dropout,
+         dropatt=config.dropatt,
+         relative_bias=config.relative_bias,
+         pos_emb=config.pos_emb,
+         pad=config.pad,
+         n_parser_layers=config.n_parser_layers,
+         conv_size=config.conv_size,
+         relations=config.relations,
+         weight_act=config.weight_act
+     )
+     self.config = config
+     # forward() reads self.num_labels; PreTrainedModel does not mirror it
+     # from the config automatically.
+     self.num_labels = config.num_labels
+     self.model.classifier = ClassificationHead(config)
+
+   def _init_weights(self, module):
+     """Initialize the weights"""
+     if isinstance(module, nn.Linear):
+       # Slightly different from the TF version which uses truncated_normal for initialization
+       # cf https://github.com/pytorch/pytorch/pull/5617
+       module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+       if module.bias is not None:
+         module.bias.data.zero_()
+     elif isinstance(module, nn.Embedding):
+       module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
+       if module.padding_idx is not None:
+         module.weight.data[module.padding_idx].zero_()
+     elif isinstance(module, nn.LayerNorm):
+       if module.bias is not None:
+         module.bias.data.zero_()
+       module.weight.data.fill_(1.0)
+
+   def forward(self, input_ids, labels=None, **kwargs):
+     x = input_ids
+     batch_size, length = x.size()
+     pos = kwargs.get(
+         'position_ids',
+         torch.arange(length, device=x.device).expand(batch_size, length))
+
+     sf_output = self.model(x, pos)
+
+     logits = self.model.classifier(sf_output[1]['raw_output'])
+     loss = None
+     if labels is not None:
+       if self.config.problem_type is None:
+         if self.num_labels == 1:
+           self.config.problem_type = "regression"
+         elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
+           self.config.problem_type = "single_label_classification"
+         else:
+           self.config.problem_type = "multi_label_classification"
+
+       if self.config.problem_type == "regression":
+         loss_fct = nn.MSELoss()
+         if self.num_labels == 1:
+           loss = loss_fct(logits.squeeze(), labels.squeeze())
+         else:
+           loss = loss_fct(logits, labels)
+       elif self.config.problem_type == "single_label_classification":
+         loss_fct = nn.CrossEntropyLoss()
+         loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
+       elif self.config.problem_type == "multi_label_classification":
+         loss_fct = nn.BCEWithLogitsLoss()
+         loss = loss_fct(logits, labels)
+
+     return SequenceClassifierOutput(
+         loss=loss,
+         logits=logits,
+         hidden_states=None,
+         attentions=None,
+     )
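And the classification wrapper, which feeds the first token's representation through `ClassificationHead`; a usage sketch under the same assumptions:

```python
import torch
from structformer import (StructFormerConfig,
                          StructFormerModelForSequenceClassification)

config = StructFormerConfig(hidden_size=64, nlayers=2, ntokens=100,
                            nhead=4, num_labels=2)
clf = StructFormerModelForSequenceClassification(config)
input_ids = torch.randint(2, 100, (2, 12))
labels = torch.tensor([0, 1])  # long dtype -> single_label_classification
out = clf(input_ids, labels=labels)
print(out.logits.shape)        # torch.Size([2, 2])
print(out.loss)                # cross-entropy over the two labels
```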