Omar committed on
Commit abe8798
1 Parent(s): 7227b33
This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. finetune/boolq/all_results.json +16 -0
  2. finetune/boolq/config.json +57 -0
  3. finetune/boolq/eval_results.json +11 -0
  4. finetune/boolq/merges.txt +0 -0
  5. finetune/boolq/predict_results.txt +724 -0
  6. finetune/boolq/pytorch_model.bin +3 -0
  7. finetune/boolq/special_tokens_map.json +15 -0
  8. finetune/boolq/structformer_as_hf.py +1123 -0
  9. finetune/boolq/tokenizer_config.json +65 -0
  10. finetune/boolq/train_results.json +8 -0
  11. finetune/boolq/trainer_state.json +25 -0
  12. finetune/boolq/training_args.bin +3 -0
  13. finetune/boolq/vocab.json +0 -0
  14. finetune/cola/all_results.json +16 -0
  15. finetune/cola/checkpoint-400/config.json +57 -0
  16. finetune/cola/checkpoint-400/merges.txt +0 -0
  17. finetune/cola/checkpoint-400/optimizer.pt +3 -0
  18. finetune/cola/checkpoint-400/pytorch_model.bin +3 -0
  19. finetune/cola/checkpoint-400/rng_state.pth +3 -0
  20. finetune/cola/checkpoint-400/scheduler.pt +3 -0
  21. finetune/cola/checkpoint-400/special_tokens_map.json +15 -0
  22. finetune/cola/checkpoint-400/structformer_as_hf.py +1123 -0
  23. finetune/cola/checkpoint-400/tokenizer_config.json +65 -0
  24. finetune/cola/checkpoint-400/trainer_state.json +27 -0
  25. finetune/cola/checkpoint-400/training_args.bin +3 -0
  26. finetune/cola/checkpoint-400/vocab.json +0 -0
  27. finetune/cola/config.json +57 -0
  28. finetune/cola/eval_results.json +11 -0
  29. finetune/cola/merges.txt +0 -0
  30. finetune/cola/predict_results.txt +1020 -0
  31. finetune/cola/pytorch_model.bin +3 -0
  32. finetune/cola/special_tokens_map.json +15 -0
  33. finetune/cola/structformer_as_hf.py +1123 -0
  34. finetune/cola/tokenizer_config.json +65 -0
  35. finetune/cola/train_results.json +8 -0
  36. finetune/cola/trainer_state.json +42 -0
  37. finetune/cola/training_args.bin +3 -0
  38. finetune/cola/vocab.json +0 -0
  39. finetune/control_raising_control/all_results.json +16 -0
  40. finetune/control_raising_control/checkpoint-400/config.json +57 -0
  41. finetune/control_raising_control/checkpoint-400/merges.txt +0 -0
  42. finetune/control_raising_control/checkpoint-400/optimizer.pt +3 -0
  43. finetune/control_raising_control/checkpoint-400/pytorch_model.bin +3 -0
  44. finetune/control_raising_control/checkpoint-400/rng_state.pth +3 -0
  45. finetune/control_raising_control/checkpoint-400/scheduler.pt +3 -0
  46. finetune/control_raising_control/checkpoint-400/special_tokens_map.json +15 -0
  47. finetune/control_raising_control/checkpoint-400/structformer_as_hf.py +1123 -0
  48. finetune/control_raising_control/checkpoint-400/tokenizer_config.json +65 -0
  49. finetune/control_raising_control/checkpoint-400/trainer_state.json +27 -0
  50. finetune/control_raising_control/checkpoint-400/training_args.bin +3 -0
finetune/boolq/all_results.json ADDED
@@ -0,0 +1,16 @@
+ {
+ "epoch": 10.0,
+ "eval_accuracy": 0.6071922779083252,
+ "eval_f1": 0.7029288702928871,
+ "eval_loss": 0.6942005157470703,
+ "eval_mcc": 0.14370055324223993,
+ "eval_runtime": 1.3999,
+ "eval_samples": 723,
+ "eval_samples_per_second": 516.453,
+ "eval_steps_per_second": 65.003,
+ "train_loss": 0.5978388892279731,
+ "train_runtime": 99.3528,
+ "train_samples": 2072,
+ "train_samples_per_second": 208.55,
+ "train_steps_per_second": 1.812
+ }
finetune/boolq/config.json ADDED
@@ -0,0 +1,57 @@
+ {
+ "_name_or_path": "omarmomen/structformer_s1_final_with_pos",
+ "architectures": [
+ "StructformerModelForSequenceClassification"
+ ],
+ "attention_probs_dropout_prob": 0.1,
+ "auto_map": {
+ "AutoConfig": "structformer_as_hf.StructformerConfig",
+ "AutoModelForMaskedLM": "structformer_as_hf.StructformerModel",
+ "AutoModelForSequenceClassification": "structformer_as_hf.StructformerModelForSequenceClassification"
+ },
+ "bos_token_id": 0,
+ "classifier_dropout": null,
+ "conv_size": 9,
+ "dropatt": 0.1,
+ "dropout": 0.1,
+ "eos_token_id": 2,
+ "hidden_act": "gelu",
+ "hidden_dropout_prob": 0.1,
+ "hidden_size": 768,
+ "id2label": {
+ "0": 0,
+ "1": 1
+ },
+ "initializer_range": 0.02,
+ "intermediate_size": 3072,
+ "label2id": {
+ "0": 0,
+ "1": 1
+ },
+ "layer_norm_eps": 1e-05,
+ "max_position_embeddings": 514,
+ "model_type": "structformer",
+ "n_context_layers": 0,
+ "n_parser_layers": 4,
+ "nhead": 12,
+ "nlayers": 12,
+ "ntokens": 32000,
+ "num_attention_heads": 8,
+ "num_hidden_layers": 8,
+ "pad": 0,
+ "pad_token_id": 1,
+ "pos_emb": true,
+ "position_embedding_type": "absolute",
+ "problem_type": "single_label_classification",
+ "relations": [
+ "head",
+ "child"
+ ],
+ "relative_bias": false,
+ "torch_dtype": "float32",
+ "transformers_version": "4.26.1",
+ "type_vocab_size": 1,
+ "use_cache": true,
+ "vocab_size": 32000,
+ "weight_act": "softmax"
+ }
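Note: because the auto_map entries in this config point at the bundled structformer_as_hf.py, the checkpoint is meant to be loaded through the Transformers auto classes with trust_remote_code=True. A minimal loading sketch, assuming a local clone of the repository; the folder path below is illustrative and not part of the commit:

# Loading sketch (assumption: local clone of the repo; the path is illustrative).
# trust_remote_code=True lets transformers import StructformerConfig and
# StructformerModelForSequenceClassification from the bundled structformer_as_hf.py.
from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer

ckpt_dir = "finetune/boolq"
config = AutoConfig.from_pretrained(ckpt_dir, trust_remote_code=True)
tokenizer = AutoTokenizer.from_pretrained(ckpt_dir)
model = AutoModelForSequenceClassification.from_pretrained(
    ckpt_dir, config=config, trust_remote_code=True
)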
finetune/boolq/eval_results.json ADDED
@@ -0,0 +1,11 @@
+ {
+ "epoch": 10.0,
+ "eval_accuracy": 0.6071922779083252,
+ "eval_f1": 0.7029288702928871,
+ "eval_loss": 0.6942005157470703,
+ "eval_mcc": 0.14370055324223993,
+ "eval_runtime": 1.3999,
+ "eval_samples": 723,
+ "eval_samples_per_second": 516.453,
+ "eval_steps_per_second": 65.003
+ }
finetune/boolq/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
finetune/boolq/predict_results.txt ADDED
@@ -0,0 +1,724 @@
1
+ index prediction
2
+ 0 1
3
+ 1 1
4
+ 2 1
5
+ 3 1
6
+ 4 1
7
+ 5 1
8
+ 6 0
9
+ 7 1
10
+ 8 1
11
+ 9 1
12
+ 10 1
13
+ 11 1
14
+ 12 1
15
+ 13 1
16
+ 14 1
17
+ 15 0
18
+ 16 1
19
+ 17 1
20
+ 18 1
21
+ 19 1
22
+ 20 1
23
+ 21 0
24
+ 22 1
25
+ 23 1
26
+ 24 1
27
+ 25 1
28
+ 26 0
29
+ 27 1
30
+ 28 0
31
+ 29 1
32
+ 30 1
33
+ 31 1
34
+ 32 1
35
+ 33 1
36
+ 34 0
37
+ 35 1
38
+ 36 1
39
+ 37 1
40
+ 38 1
41
+ 39 1
42
+ 40 0
43
+ 41 1
44
+ 42 1
45
+ 43 1
46
+ 44 1
47
+ 45 1
48
+ 46 1
49
+ 47 1
50
+ 48 1
51
+ 49 1
52
+ 50 1
53
+ 51 1
54
+ 52 0
55
+ 53 1
56
+ 54 1
57
+ 55 1
58
+ 56 1
59
+ 57 1
60
+ 58 1
61
+ 59 1
62
+ 60 1
63
+ 61 0
64
+ 62 1
65
+ 63 1
66
+ 64 1
67
+ 65 1
68
+ 66 0
69
+ 67 1
70
+ 68 1
71
+ 69 1
72
+ 70 1
73
+ 71 0
74
+ 72 1
75
+ 73 0
76
+ 74 0
77
+ 75 1
78
+ 76 1
79
+ 77 0
80
+ 78 0
81
+ 79 0
82
+ 80 1
83
+ 81 0
84
+ 82 0
85
+ 83 1
86
+ 84 1
87
+ 85 1
88
+ 86 0
89
+ 87 1
90
+ 88 1
91
+ 89 1
92
+ 90 1
93
+ 91 1
94
+ 92 1
95
+ 93 1
96
+ 94 0
97
+ 95 1
98
+ 96 0
99
+ 97 1
100
+ 98 0
101
+ 99 1
102
+ 100 1
103
+ 101 0
104
+ 102 1
105
+ 103 0
106
+ 104 1
107
+ 105 1
108
+ 106 1
109
+ 107 0
110
+ 108 1
111
+ 109 0
112
+ 110 1
113
+ 111 1
114
+ 112 1
115
+ 113 1
116
+ 114 1
117
+ 115 1
118
+ 116 1
119
+ 117 1
120
+ 118 1
121
+ 119 1
122
+ 120 0
123
+ 121 1
124
+ 122 1
125
+ 123 0
126
+ 124 1
127
+ 125 1
128
+ 126 1
129
+ 127 1
130
+ 128 1
131
+ 129 0
132
+ 130 1
133
+ 131 0
134
+ 132 1
135
+ 133 1
136
+ 134 1
137
+ 135 1
138
+ 136 1
139
+ 137 0
140
+ 138 1
141
+ 139 1
142
+ 140 1
143
+ 141 1
144
+ 142 1
145
+ 143 1
146
+ 144 1
147
+ 145 1
148
+ 146 0
149
+ 147 1
150
+ 148 0
151
+ 149 0
152
+ 150 1
153
+ 151 1
154
+ 152 1
155
+ 153 1
156
+ 154 1
157
+ 155 1
158
+ 156 0
159
+ 157 1
160
+ 158 1
161
+ 159 0
162
+ 160 1
163
+ 161 1
164
+ 162 1
165
+ 163 0
166
+ 164 1
167
+ 165 0
168
+ 166 1
169
+ 167 1
170
+ 168 0
171
+ 169 0
172
+ 170 1
173
+ 171 0
174
+ 172 1
175
+ 173 1
176
+ 174 1
177
+ 175 0
178
+ 176 0
179
+ 177 1
180
+ 178 0
181
+ 179 1
182
+ 180 1
183
+ 181 1
184
+ 182 1
185
+ 183 0
186
+ 184 0
187
+ 185 1
188
+ 186 0
189
+ 187 0
190
+ 188 1
191
+ 189 1
192
+ 190 1
193
+ 191 1
194
+ 192 1
195
+ 193 0
196
+ 194 1
197
+ 195 0
198
+ 196 1
199
+ 197 1
200
+ 198 0
201
+ 199 1
202
+ 200 1
203
+ 201 1
204
+ 202 1
205
+ 203 1
206
+ 204 1
207
+ 205 0
208
+ 206 1
209
+ 207 1
210
+ 208 1
211
+ 209 1
212
+ 210 1
213
+ 211 1
214
+ 212 1
215
+ 213 0
216
+ 214 0
217
+ 215 1
218
+ 216 1
219
+ 217 1
220
+ 218 1
221
+ 219 0
222
+ 220 1
223
+ 221 1
224
+ 222 1
225
+ 223 0
226
+ 224 0
227
+ 225 1
228
+ 226 1
229
+ 227 0
230
+ 228 1
231
+ 229 1
232
+ 230 1
233
+ 231 1
234
+ 232 1
235
+ 233 0
236
+ 234 1
237
+ 235 0
238
+ 236 0
239
+ 237 1
240
+ 238 1
241
+ 239 1
242
+ 240 1
243
+ 241 0
244
+ 242 1
245
+ 243 1
246
+ 244 0
247
+ 245 1
248
+ 246 1
249
+ 247 0
250
+ 248 0
251
+ 249 1
252
+ 250 0
253
+ 251 1
254
+ 252 1
255
+ 253 1
256
+ 254 1
257
+ 255 0
258
+ 256 0
259
+ 257 1
260
+ 258 1
261
+ 259 0
262
+ 260 1
263
+ 261 0
264
+ 262 1
265
+ 263 1
266
+ 264 1
267
+ 265 1
268
+ 266 1
269
+ 267 1
270
+ 268 1
271
+ 269 1
272
+ 270 1
273
+ 271 0
274
+ 272 1
275
+ 273 1
276
+ 274 1
277
+ 275 0
278
+ 276 1
279
+ 277 1
280
+ 278 1
281
+ 279 1
282
+ 280 0
283
+ 281 0
284
+ 282 0
285
+ 283 1
286
+ 284 0
287
+ 285 0
288
+ 286 0
289
+ 287 1
290
+ 288 1
291
+ 289 0
292
+ 290 1
293
+ 291 0
294
+ 292 1
295
+ 293 1
296
+ 294 1
297
+ 295 1
298
+ 296 0
299
+ 297 0
300
+ 298 1
301
+ 299 1
302
+ 300 1
303
+ 301 0
304
+ 302 0
305
+ 303 0
306
+ 304 0
307
+ 305 0
308
+ 306 0
309
+ 307 1
310
+ 308 1
311
+ 309 0
312
+ 310 0
313
+ 311 1
314
+ 312 1
315
+ 313 1
316
+ 314 0
317
+ 315 1
318
+ 316 1
319
+ 317 1
320
+ 318 1
321
+ 319 1
322
+ 320 1
323
+ 321 0
324
+ 322 1
325
+ 323 1
326
+ 324 1
327
+ 325 1
328
+ 326 1
329
+ 327 1
330
+ 328 1
331
+ 329 1
332
+ 330 1
333
+ 331 1
334
+ 332 1
335
+ 333 1
336
+ 334 0
337
+ 335 1
338
+ 336 0
339
+ 337 1
340
+ 338 1
341
+ 339 1
342
+ 340 1
343
+ 341 0
344
+ 342 1
345
+ 343 1
346
+ 344 1
347
+ 345 0
348
+ 346 1
349
+ 347 1
350
+ 348 1
351
+ 349 1
352
+ 350 1
353
+ 351 1
354
+ 352 0
355
+ 353 1
356
+ 354 1
357
+ 355 1
358
+ 356 0
359
+ 357 1
360
+ 358 1
361
+ 359 1
362
+ 360 1
363
+ 361 0
364
+ 362 0
365
+ 363 1
366
+ 364 1
367
+ 365 1
368
+ 366 0
369
+ 367 1
370
+ 368 1
371
+ 369 1
372
+ 370 0
373
+ 371 1
374
+ 372 1
375
+ 373 1
376
+ 374 1
377
+ 375 1
378
+ 376 1
379
+ 377 1
380
+ 378 1
381
+ 379 1
382
+ 380 1
383
+ 381 1
384
+ 382 1
385
+ 383 0
386
+ 384 1
387
+ 385 0
388
+ 386 1
389
+ 387 0
390
+ 388 1
391
+ 389 1
392
+ 390 1
393
+ 391 1
394
+ 392 1
395
+ 393 0
396
+ 394 1
397
+ 395 1
398
+ 396 1
399
+ 397 1
400
+ 398 1
401
+ 399 1
402
+ 400 1
403
+ 401 0
404
+ 402 0
405
+ 403 0
406
+ 404 1
407
+ 405 1
408
+ 406 0
409
+ 407 0
410
+ 408 1
411
+ 409 1
412
+ 410 1
413
+ 411 0
414
+ 412 0
415
+ 413 1
416
+ 414 1
417
+ 415 0
418
+ 416 1
419
+ 417 0
420
+ 418 0
421
+ 419 1
422
+ 420 1
423
+ 421 1
424
+ 422 1
425
+ 423 1
426
+ 424 1
427
+ 425 1
428
+ 426 1
429
+ 427 1
430
+ 428 1
431
+ 429 1
432
+ 430 1
433
+ 431 1
434
+ 432 1
435
+ 433 1
436
+ 434 1
437
+ 435 0
438
+ 436 1
439
+ 437 1
440
+ 438 0
441
+ 439 0
442
+ 440 0
443
+ 441 0
444
+ 442 1
445
+ 443 1
446
+ 444 0
447
+ 445 1
448
+ 446 1
449
+ 447 1
450
+ 448 1
451
+ 449 1
452
+ 450 1
453
+ 451 1
454
+ 452 1
455
+ 453 0
456
+ 454 1
457
+ 455 1
458
+ 456 0
459
+ 457 1
460
+ 458 1
461
+ 459 1
462
+ 460 1
463
+ 461 1
464
+ 462 1
465
+ 463 1
466
+ 464 1
467
+ 465 0
468
+ 466 1
469
+ 467 1
470
+ 468 1
471
+ 469 0
472
+ 470 0
473
+ 471 1
474
+ 472 1
475
+ 473 1
476
+ 474 0
477
+ 475 1
478
+ 476 1
479
+ 477 0
480
+ 478 1
481
+ 479 1
482
+ 480 1
483
+ 481 1
484
+ 482 0
485
+ 483 1
486
+ 484 1
487
+ 485 1
488
+ 486 1
489
+ 487 1
490
+ 488 1
491
+ 489 1
492
+ 490 1
493
+ 491 1
494
+ 492 1
495
+ 493 1
496
+ 494 1
497
+ 495 1
498
+ 496 0
499
+ 497 0
500
+ 498 1
501
+ 499 1
502
+ 500 1
503
+ 501 1
504
+ 502 0
505
+ 503 0
506
+ 504 1
507
+ 505 0
508
+ 506 1
509
+ 507 1
510
+ 508 1
511
+ 509 1
512
+ 510 0
513
+ 511 0
514
+ 512 1
515
+ 513 1
516
+ 514 1
517
+ 515 1
518
+ 516 0
519
+ 517 1
520
+ 518 1
521
+ 519 1
522
+ 520 0
523
+ 521 1
524
+ 522 1
525
+ 523 1
526
+ 524 0
527
+ 525 1
528
+ 526 0
529
+ 527 1
530
+ 528 1
531
+ 529 0
532
+ 530 1
533
+ 531 1
534
+ 532 1
535
+ 533 1
536
+ 534 1
537
+ 535 0
538
+ 536 1
539
+ 537 0
540
+ 538 1
541
+ 539 0
542
+ 540 1
543
+ 541 1
544
+ 542 0
545
+ 543 1
546
+ 544 1
547
+ 545 1
548
+ 546 1
549
+ 547 1
550
+ 548 1
551
+ 549 1
552
+ 550 1
553
+ 551 1
554
+ 552 0
555
+ 553 1
556
+ 554 0
557
+ 555 1
558
+ 556 1
559
+ 557 0
560
+ 558 0
561
+ 559 1
562
+ 560 1
563
+ 561 1
564
+ 562 1
565
+ 563 1
566
+ 564 1
567
+ 565 0
568
+ 566 1
569
+ 567 0
570
+ 568 1
571
+ 569 1
572
+ 570 1
573
+ 571 1
574
+ 572 0
575
+ 573 1
576
+ 574 1
577
+ 575 1
578
+ 576 0
579
+ 577 1
580
+ 578 1
581
+ 579 0
582
+ 580 0
583
+ 581 0
584
+ 582 0
585
+ 583 0
586
+ 584 1
587
+ 585 1
588
+ 586 1
589
+ 587 1
590
+ 588 0
591
+ 589 1
592
+ 590 0
593
+ 591 1
594
+ 592 0
595
+ 593 1
596
+ 594 1
597
+ 595 1
598
+ 596 0
599
+ 597 1
600
+ 598 0
601
+ 599 1
602
+ 600 1
603
+ 601 1
604
+ 602 1
605
+ 603 1
606
+ 604 1
607
+ 605 1
608
+ 606 0
609
+ 607 1
610
+ 608 1
611
+ 609 1
612
+ 610 1
613
+ 611 1
614
+ 612 1
615
+ 613 1
616
+ 614 1
617
+ 615 1
618
+ 616 0
619
+ 617 1
620
+ 618 1
621
+ 619 1
622
+ 620 1
623
+ 621 1
624
+ 622 0
625
+ 623 1
626
+ 624 1
627
+ 625 1
628
+ 626 0
629
+ 627 1
630
+ 628 1
631
+ 629 0
632
+ 630 1
633
+ 631 1
634
+ 632 1
635
+ 633 1
636
+ 634 1
637
+ 635 0
638
+ 636 1
639
+ 637 0
640
+ 638 1
641
+ 639 1
642
+ 640 1
643
+ 641 0
644
+ 642 0
645
+ 643 0
646
+ 644 1
647
+ 645 1
648
+ 646 1
649
+ 647 1
650
+ 648 1
651
+ 649 0
652
+ 650 1
653
+ 651 1
654
+ 652 0
655
+ 653 0
656
+ 654 0
657
+ 655 0
658
+ 656 1
659
+ 657 0
660
+ 658 1
661
+ 659 1
662
+ 660 0
663
+ 661 1
664
+ 662 1
665
+ 663 1
666
+ 664 0
667
+ 665 1
668
+ 666 1
669
+ 667 1
670
+ 668 0
671
+ 669 1
672
+ 670 1
673
+ 671 1
674
+ 672 1
675
+ 673 1
676
+ 674 1
677
+ 675 0
678
+ 676 1
679
+ 677 1
680
+ 678 1
681
+ 679 0
682
+ 680 1
683
+ 681 1
684
+ 682 0
685
+ 683 1
686
+ 684 1
687
+ 685 1
688
+ 686 1
689
+ 687 1
690
+ 688 0
691
+ 689 1
692
+ 690 1
693
+ 691 0
694
+ 692 0
695
+ 693 1
696
+ 694 1
697
+ 695 1
698
+ 696 1
699
+ 697 0
700
+ 698 1
701
+ 699 1
702
+ 700 0
703
+ 701 0
704
+ 702 1
705
+ 703 1
706
+ 704 1
707
+ 705 1
708
+ 706 1
709
+ 707 1
710
+ 708 1
711
+ 709 1
712
+ 710 1
713
+ 711 0
714
+ 712 0
715
+ 713 1
716
+ 714 0
717
+ 715 0
718
+ 716 0
719
+ 717 1
720
+ 718 0
721
+ 719 1
722
+ 720 0
723
+ 721 0
724
+ 722 1
finetune/boolq/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:3ca8ff33e534e52b8aec72bb5a89e3f6fbfcae473ea395e225305e360c7e1762
+ size 534669003
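The file above is stored as a Git LFS pointer rather than the tensors themselves. A small verification sketch, assuming the real 534,669,003-byte weight file has been fetched via Git LFS into the same path (the path is illustrative), checking it against the sha256 recorded in the pointer:

# Verification sketch (assumption: the real file was fetched via Git LFS; path is illustrative).
import hashlib

path = "finetune/boolq/pytorch_model.bin"
expected = "3ca8ff33e534e52b8aec72bb5a89e3f6fbfcae473ea395e225305e360c7e1762"

digest = hashlib.sha256()
with open(path, "rb") as f:
    for chunk in iter(lambda: f.read(1 << 20), b""):
        digest.update(chunk)

assert digest.hexdigest() == expected, "mismatch: file may still be an LFS pointer or is corrupted"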
finetune/boolq/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
+ {
+ "bos_token": "<s>",
+ "cls_token": "<s>",
+ "eos_token": "</s>",
+ "mask_token": {
+ "content": "<mask>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "pad_token": "<pad>",
+ "sep_token": "</s>",
+ "unk_token": "<unk>"
+ }
finetune/boolq/structformer_as_hf.py ADDED
@@ -0,0 +1,1123 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ from transformers import PreTrainedModel
6
+ from transformers import PretrainedConfig
7
+ from transformers.modeling_outputs import MaskedLMOutput
8
+ from typing import List
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ MaskedLMOutput,
14
+ SequenceClassifierOutput
15
+ )
16
+
17
+ ##########################################
18
+ # HuggingFace Config
19
+ ##########################################
20
+ class StructformerConfig(PretrainedConfig):
21
+ model_type = "structformer"
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size=768,
26
+ n_context_layers=2,
27
+ nlayers=6,
28
+ ntokens=32000,
29
+ nhead=8,
30
+ dropout=0.1,
31
+ dropatt=0.1,
32
+ relative_bias=False,
33
+ pos_emb=False,
34
+ pad=0,
35
+ n_parser_layers=4,
36
+ conv_size=9,
37
+ relations=('head', 'child'),
38
+ weight_act='softmax',
39
+ **kwargs,
40
+ ):
41
+ self.hidden_size = hidden_size
42
+ self.n_context_layers = n_context_layers
43
+ self.nlayers = nlayers
44
+ self.ntokens = ntokens
45
+ self.nhead = nhead
46
+ self.dropout = dropout
47
+ self.dropatt = dropatt
48
+ self.relative_bias = relative_bias
49
+ self.pos_emb = pos_emb
50
+ self.pad = pad
51
+ self.n_parser_layers = n_parser_layers
52
+ self.conv_size = conv_size
53
+ self.relations = relations
54
+ self.weight_act = weight_act
55
+ super().__init__(**kwargs)
56
+
57
+ ##########################################
58
+ # Custom Layers
59
+ ##########################################
60
+ def _get_activation_fn(activation):
61
+ """Get specified activation function."""
62
+ if activation == "relu":
63
+ return nn.ReLU()
64
+ elif activation == "gelu":
65
+ return nn.GELU()
66
+ elif activation == "leakyrelu":
67
+ return nn.LeakyReLU()
68
+
69
+ raise RuntimeError(
70
+ "activation should be relu/gelu, not {}".format(activation))
71
+
72
+ class Conv1d(nn.Module):
73
+ """1D convolution layer."""
74
+
75
+ def __init__(self, hidden_size, kernel_size, dilation=1):
76
+ """Initialization.
77
+ Args:
78
+ hidden_size: dimension of input embeddings
79
+ kernel_size: convolution kernel size
80
+ dilation: the spacing between the kernel points
81
+ """
82
+ super(Conv1d, self).__init__()
83
+
84
+ if kernel_size % 2 == 0:
85
+ padding = (kernel_size // 2) * dilation
86
+ self.shift = True
87
+ else:
88
+ padding = ((kernel_size - 1) // 2) * dilation
89
+ self.shift = False
90
+ self.conv = nn.Conv1d(
91
+ hidden_size,
92
+ hidden_size,
93
+ kernel_size,
94
+ padding=padding,
95
+ dilation=dilation)
96
+
97
+ def forward(self, x):
98
+ """Compute convolution.
99
+ Args:
100
+ x: input embeddings
101
+ Returns:
102
+ conv_output: convolution results
103
+ """
104
+
105
+ if self.shift:
106
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
107
+ else:
108
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
109
+
110
+ class MultiheadAttention(nn.Module):
111
+ """Multi-head self-attention layer."""
112
+
113
+ def __init__(self,
114
+ embed_dim,
115
+ num_heads,
116
+ dropout=0.,
117
+ bias=True,
118
+ v_proj=True,
119
+ out_proj=True,
120
+ relative_bias=True):
121
+ """Initialization.
122
+ Args:
123
+ embed_dim: dimension of input embeddings
124
+ num_heads: number of self-attention heads
125
+ dropout: dropout rate
126
+ bias: bool, indicate whether include bias for linear transformations
127
+ v_proj: bool, indicate whether project inputs to new values
128
+ out_proj: bool, indicate whether project outputs to new values
129
+ relative_bias: bool, indicate whether use a relative position based
130
+ attention bias
131
+ """
132
+
133
+ super(MultiheadAttention, self).__init__()
134
+ self.embed_dim = embed_dim
135
+
136
+ self.num_heads = num_heads
137
+ self.drop = nn.Dropout(dropout)
138
+ self.head_dim = embed_dim // num_heads
139
+ assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
140
+ "divisible by "
141
+ "num_heads")
142
+
143
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
144
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
145
+ if v_proj:
146
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
147
+ else:
148
+ self.v_proj = nn.Identity()
149
+
150
+ if out_proj:
151
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
152
+ else:
153
+ self.out_proj = nn.Identity()
154
+
155
+ if relative_bias:
156
+ self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
157
+ else:
158
+ self.relative_bias = None
159
+
160
+ self._reset_parameters()
161
+
162
+ def _reset_parameters(self):
163
+ """Initialize attention parameters."""
164
+
165
+ init.xavier_uniform_(self.q_proj.weight)
166
+ init.constant_(self.q_proj.bias, 0.)
167
+
168
+ init.xavier_uniform_(self.k_proj.weight)
169
+ init.constant_(self.k_proj.bias, 0.)
170
+
171
+ if isinstance(self.v_proj, nn.Linear):
172
+ init.xavier_uniform_(self.v_proj.weight)
173
+ init.constant_(self.v_proj.bias, 0.)
174
+
175
+ if isinstance(self.out_proj, nn.Linear):
176
+ init.xavier_uniform_(self.out_proj.weight)
177
+ init.constant_(self.out_proj.bias, 0.)
178
+
179
+ def forward(self, query, key_padding_mask=None, attn_mask=None):
180
+ """Compute multi-head self-attention.
181
+ Args:
182
+ query: input embeddings
183
+ key_padding_mask: 3D mask that prevents attention to certain positions
184
+ attn_mask: 3D mask that rescale the attention weight at each position
185
+ Returns:
186
+ attn_output: self-attention output
187
+ """
188
+
189
+ length, bsz, embed_dim = query.size()
190
+ assert embed_dim == self.embed_dim
191
+
192
+ head_dim = embed_dim // self.num_heads
193
+ assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
194
+ "divisible by num_heads")
195
+ scaling = float(head_dim)**-0.5
196
+
197
+ q = self.q_proj(query)
198
+ k = self.k_proj(query)
199
+ v = self.v_proj(query)
200
+
201
+ q = q * scaling
202
+
203
+ if attn_mask is not None:
204
+ assert list(attn_mask.size()) == [bsz * self.num_heads,
205
+ query.size(0), query.size(0)]
206
+
207
+ q = q.contiguous().view(length, bsz * self.num_heads,
208
+ head_dim).transpose(0, 1)
209
+ k = k.contiguous().view(length, bsz * self.num_heads,
210
+ head_dim).transpose(0, 1)
211
+ v = v.contiguous().view(length, bsz * self.num_heads,
212
+ head_dim).transpose(0, 1)
213
+
214
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
215
+ assert list(
216
+ attn_output_weights.size()) == [bsz * self.num_heads, length, length]
217
+
218
+ if self.relative_bias is not None:
219
+ pos = torch.arange(length, device=query.device)
220
+ relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
221
+ relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
222
+ -1)
223
+
224
+ relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
225
+ relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
226
+ relative_bias = torch.gather(relative_bias, 2, relative_pos)
227
+ attn_output_weights = attn_output_weights + relative_bias
228
+
229
+ if key_padding_mask is not None:
230
+ attn_output_weights = attn_output_weights + key_padding_mask
231
+
232
+ if attn_mask is None:
233
+ attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
234
+ else:
235
+ attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
236
+
237
+ attn_output_weights = self.drop(attn_output_weights)
238
+
239
+ attn_output = torch.bmm(attn_output_weights, v)
240
+
241
+ assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
242
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
243
+ length, bsz, embed_dim)
244
+ attn_output = self.out_proj(attn_output)
245
+
246
+ return attn_output
247
+
248
+ class TransformerLayer(nn.Module):
249
+ """TransformerEncoderLayer is made up of self-attn and feedforward network."""
250
+
251
+ def __init__(self,
252
+ d_model,
253
+ nhead,
254
+ dim_feedforward=2048,
255
+ dropout=0.1,
256
+ dropatt=0.1,
257
+ activation="leakyrelu",
258
+ relative_bias=True):
259
+ """Initialization.
260
+ Args:
261
+ d_model: dimension of inputs
262
+ nhead: number of self-attention heads
263
+ dim_feedforward: dimension of hidden layer in feedforward layer
264
+ dropout: dropout rate
265
+ dropatt: drop attention rate
266
+ activation: activation function
267
+ relative_bias: bool, indicate whether use a relative position based
268
+ attention bias
269
+ """
270
+
271
+ super(TransformerLayer, self).__init__()
272
+
273
+ self.self_attn = MultiheadAttention(
274
+ d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
275
+
276
+ # Implementation of Feedforward model
277
+ self.feedforward = nn.Sequential(
278
+ nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
279
+ _get_activation_fn(activation), nn.Dropout(dropout),
280
+ nn.Linear(dim_feedforward, d_model))
281
+
282
+ self.norm = nn.LayerNorm(d_model)
283
+ self.dropout1 = nn.Dropout(dropout)
284
+ self.dropout2 = nn.Dropout(dropout)
285
+
286
+ self.nhead = nhead
287
+
288
+ def forward(self, src, attn_mask=None, key_padding_mask=None):
289
+ """Pass the input through the encoder layer.
290
+ Args:
291
+ src: the sequence to the encoder layer (required).
292
+ attn_mask: the mask for the src sequence (optional).
293
+ key_padding_mask: the mask for the src keys per batch (optional).
294
+ Returns:
295
+ src3: the output of transformer layer, share the same shape as src.
296
+ """
297
+ src2 = self.self_attn(
298
+ self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
299
+ src2 = src + self.dropout1(src2)
300
+ src3 = self.feedforward(src2)
301
+ src3 = src2 + self.dropout2(src3)
302
+
303
+ return src3
304
+
305
+
306
+
307
+ class RobertaClassificationHead(nn.Module):
308
+ """Head for sentence-level classification tasks."""
309
+
310
+ def __init__(self, config):
311
+ super().__init__()
312
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
313
+ classifier_dropout = (
314
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
315
+ )
316
+ self.dropout = nn.Dropout(classifier_dropout)
317
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
318
+
319
+ def forward(self, features, **kwargs):
320
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
321
+ x = self.dropout(x)
322
+ x = self.dense(x)
323
+ x = torch.tanh(x)
324
+ x = self.dropout(x)
325
+ x = self.out_proj(x)
326
+ return x
327
+
328
+
329
+ ##########################################
330
+ # Custom Models
331
+ ##########################################
332
+ def cumprod(x, reverse=False, exclusive=False):
333
+ """cumulative product."""
334
+ if reverse:
335
+ x = x.flip([-1])
336
+
337
+ if exclusive:
338
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
339
+
340
+ cx = x.cumprod(-1)
341
+
342
+ if reverse:
343
+ cx = cx.flip([-1])
344
+ return cx
345
+
346
+ def cumsum(x, reverse=False, exclusive=False):
347
+ """cumulative sum."""
348
+ bsz, _, length = x.size()
349
+ device = x.device
350
+ if reverse:
351
+ if exclusive:
352
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
353
+ else:
354
+ w = torch.ones([bsz, length, length], device=device).tril(0)
355
+ cx = torch.bmm(x, w)
356
+ else:
357
+ if exclusive:
358
+ w = torch.ones([bsz, length, length], device=device).triu(1)
359
+ else:
360
+ w = torch.ones([bsz, length, length], device=device).triu(0)
361
+ cx = torch.bmm(x, w)
362
+ return cx
363
+
364
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
365
+ """cumulative min."""
366
+ if reverse:
367
+ if exclusive:
368
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
369
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
370
+ else:
371
+ if exclusive:
372
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
373
+ x = x.cummin(-1)[0]
374
+ return x
375
+
376
+ class Transformer(nn.Module):
377
+ """Transformer model."""
378
+
379
+ def __init__(self,
380
+ hidden_size,
381
+ nlayers,
382
+ ntokens,
383
+ nhead=8,
384
+ dropout=0.1,
385
+ dropatt=0.1,
386
+ relative_bias=True,
387
+ pos_emb=False,
388
+ pad=0):
389
+ """Initialization.
390
+ Args:
391
+ hidden_size: dimension of inputs and hidden states
392
+ nlayers: number of layers
393
+ ntokens: number of output categories
394
+ nhead: number of self-attention heads
395
+ dropout: dropout rate
396
+ dropatt: drop attention rate
397
+ relative_bias: bool, indicate whether use a relative position based
398
+ attention bias
399
+ pos_emb: bool, indicate whether use a learnable positional embedding
400
+ pad: pad token index
401
+ """
402
+
403
+ super(Transformer, self).__init__()
404
+
405
+ self.drop = nn.Dropout(dropout)
406
+
407
+ self.emb = nn.Embedding(ntokens, hidden_size)
408
+ if pos_emb:
409
+ self.pos_emb = nn.Embedding(500, hidden_size)
410
+
411
+ self.layers = nn.ModuleList([
412
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
413
+ dropatt=dropatt, relative_bias=relative_bias)
414
+ for _ in range(nlayers)])
415
+
416
+ self.norm = nn.LayerNorm(hidden_size)
417
+
418
+ self.output_layer = nn.Linear(hidden_size, ntokens)
419
+ self.output_layer.weight = self.emb.weight
420
+
421
+ self.init_weights()
422
+
423
+ self.nlayers = nlayers
424
+ self.nhead = nhead
425
+ self.ntokens = ntokens
426
+ self.hidden_size = hidden_size
427
+ self.pad = pad
428
+
429
+ def init_weights(self):
430
+ """Initialize token embedding and output bias."""
431
+ initrange = 0.1
432
+ self.emb.weight.data.uniform_(-initrange, initrange)
433
+ if hasattr(self, 'pos_emb'):
434
+ self.pos_emb.weight.data.uniform_(-initrange, initrange)
435
+ self.output_layer.bias.data.fill_(0)
436
+
437
+ def visibility(self, x, device):
438
+ """Mask pad tokens."""
439
+ visibility = (x != self.pad).float()
440
+ visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
441
+ visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
442
+ return visibility.log()
443
+
444
+ def encode(self, x, pos):
445
+ """Standard transformer encode process."""
446
+ h = self.emb(x)
447
+ if hasattr(self, 'pos_emb'):
448
+ h = h + self.pos_emb(pos)
449
+ h_list = []
450
+ visibility = self.visibility(x, x.device)
451
+
452
+ for i in range(self.nlayers):
453
+ h_list.append(h)
454
+ h = self.layers[i](
455
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
456
+
457
+ output = h
458
+ h_array = torch.stack(h_list, dim=2)
459
+
460
+ return output, h_array
461
+
462
+ def forward(self, x, pos):
463
+ """Pass the input through the encoder layer.
464
+ Args:
465
+ x: input tokens (required).
466
+ pos: position for each token (optional).
467
+ Returns:
468
+ output: probability distributions for missing tokens.
469
+ state_dict: parsing results and raw output
470
+ """
471
+
472
+ batch_size, length = x.size()
473
+
474
+ raw_output, _ = self.encode(x, pos)
475
+ raw_output = self.norm(raw_output)
476
+ raw_output = self.drop(raw_output)
477
+
478
+ output = self.output_layer(raw_output)
479
+ return output.view(batch_size * length, -1), {'raw_output': raw_output,}
480
+
481
+ class StructFormer(Transformer):
482
+ """StructFormer model."""
483
+
484
+ def __init__(self,
485
+ hidden_size,
486
+ n_context_layers,
487
+ nlayers,
488
+ ntokens,
489
+ nhead=8,
490
+ dropout=0.1,
491
+ dropatt=0.1,
492
+ relative_bias=False,
493
+ pos_emb=False,
494
+ pad=0,
495
+ n_parser_layers=4,
496
+ conv_size=9,
497
+ relations=('head', 'child'),
498
+ weight_act='softmax'):
499
+ """Initialization.
500
+ Args:
501
+ hidden_size: dimension of inputs and hidden states
502
+ nlayers: number of layers
503
+ ntokens: number of output categories
504
+ nhead: number of self-attention heads
505
+ dropout: dropout rate
506
+ dropatt: drop attention rate
507
+ relative_bias: bool, indicate whether use a relative position based
508
+ attention bias
509
+ pos_emb: bool, indicate whether use a learnable positional embedding
510
+ pad: pad token index
511
+ n_parser_layers: number of parsing layers
512
+ conv_size: convolution kernel size for parser
513
+ relations: relations that are used to compute self attention
514
+ weight_act: relations distribution activation function
515
+ """
516
+
517
+ super(StructFormer, self).__init__(
518
+ hidden_size,
519
+ nlayers,
520
+ ntokens,
521
+ nhead=nhead,
522
+ dropout=dropout,
523
+ dropatt=dropatt,
524
+ relative_bias=relative_bias,
525
+ pos_emb=pos_emb,
526
+ pad=pad)
527
+
528
+ if n_context_layers > 0:
529
+ self.context_layers = nn.ModuleList([
530
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
531
+ dropatt=dropatt, relative_bias=relative_bias)
532
+ for _ in range(n_context_layers)])
533
+
534
+ self.parser_layers = nn.ModuleList([
535
+ nn.Sequential(Conv1d(hidden_size, conv_size),
536
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
537
+ nn.Tanh()) for i in range(n_parser_layers)])
538
+
539
+ self.distance_ff = nn.Sequential(
540
+ Conv1d(hidden_size, 2),
541
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
542
+ nn.Linear(hidden_size, 1))
543
+
544
+ self.height_ff = nn.Sequential(
545
+ nn.Linear(hidden_size, hidden_size),
546
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
547
+ nn.Linear(hidden_size, 1))
548
+
549
+ n_rel = len(relations)
550
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
551
+ self._rel_weight.data.normal_(0, 0.1)
552
+
553
+ self._scaler = nn.Parameter(torch.zeros(2))
554
+
555
+ self.n_parse_layers = n_parser_layers
556
+ self.n_context_layers = n_context_layers
557
+ self.weight_act = weight_act
558
+ self.relations = relations
559
+
560
+ @property
561
+ def scaler(self):
562
+ return self._scaler.exp()
563
+
564
+ @property
565
+ def rel_weight(self):
566
+ if self.weight_act == 'sigmoid':
567
+ return torch.sigmoid(self._rel_weight)
568
+ elif self.weight_act == 'softmax':
569
+ return torch.softmax(self._rel_weight, dim=-1)
570
+
571
+ def parse(self, x, pos, embeds=None):
572
+ """Parse input sentence.
573
+ Args:
574
+ x: input tokens (required).
575
+ pos: position for each token (optional).
576
+ Returns:
577
+ distance: syntactic distance
578
+ height: syntactic height
579
+ """
580
+
581
+ mask = (x != self.pad)
582
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
583
+
584
+
585
+ if embeds is not None:
586
+ h = embeds
587
+ else:
588
+ h = self.emb(x)
589
+
590
+ for i in range(self.n_parse_layers):
591
+ h = h.masked_fill(~mask[:, :, None], 0)
592
+ h = self.parser_layers[i](h)
593
+
594
+ height = self.height_ff(h).squeeze(-1)
595
+ height.masked_fill_(~mask, -1e9)
596
+
597
+ distance = self.distance_ff(h).squeeze(-1)
598
+ distance.masked_fill_(~mask_shifted, 1e9)
599
+
600
+ # Calibrating the distance and height to the same level
601
+ length = distance.size(1)
602
+ height_max = height[:, None, :].expand(-1, length, -1)
603
+ height_max = torch.cummax(
604
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
605
+ dim=-1)[0].triu(0)
606
+
607
+ margin_left = torch.relu(
608
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
609
+ margin_right = torch.relu(distance[:, None, :] - height_max)
610
+ margin = torch.where(margin_left > margin_right, margin_right,
611
+ margin_left).triu(0)
612
+
613
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
614
+ margin.masked_fill_(~margin_mask, 0)
615
+ margin = margin.max()
616
+
617
+ distance = distance - margin
618
+
619
+ return distance, height
620
+
621
+ def compute_block(self, distance, height):
622
+ """Compute constituents from distance and height."""
623
+
624
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
625
+
626
+ gamma = torch.sigmoid(-beta_logits)
627
+ ones = torch.ones_like(gamma)
628
+
629
+ block_mask_left = cummin(
630
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
631
+ block_mask_left = block_mask_left - F.pad(
632
+ block_mask_left[:, :, :-1], (1, 0), value=0)
633
+ block_mask_left.tril_(0)
634
+
635
+ block_mask_right = cummin(
636
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
637
+ block_mask_right = block_mask_right - F.pad(
638
+ block_mask_right[:, :, 1:], (0, 1), value=0)
639
+ block_mask_right.triu_(0)
640
+
641
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
642
+ block = cumsum(block_mask_left).tril(0) + cumsum(
643
+ block_mask_right, reverse=True).triu(1)
644
+
645
+ return block_p, block
646
+
647
+ def compute_head(self, height):
648
+ """Estimate head for each constituent."""
649
+
650
+ _, length = height.size()
651
+ head_logits = height * self.scaler[1]
652
+ index = torch.arange(length, device=height.device)
653
+
654
+ mask = (index[:, None, None] <= index[None, None, :]) * (
655
+ index[None, None, :] <= index[None, :, None])
656
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
657
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
658
+
659
+ head_p = torch.softmax(head_logits, dim=-1)
660
+
661
+ return head_p
662
+
663
+ def generate_mask(self, x, distance, height):
664
+ """Compute head and cibling distribution for each token."""
665
+
666
+ bsz, length = x.size()
667
+
668
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
669
+ eye = eye[None, :, :].expand((bsz, -1, -1))
670
+
671
+ block_p, block = self.compute_block(distance, height)
672
+ head_p = self.compute_head(height)
673
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
674
+ head = head.masked_fill(eye, 0)
675
+ child = head.transpose(1, 2)
676
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
677
+
678
+ rel_list = []
679
+ if 'head' in self.relations:
680
+ rel_list.append(head)
681
+ if 'child' in self.relations:
682
+ rel_list.append(child)
683
+ if 'cibling' in self.relations:
684
+ rel_list.append(cibling)
685
+
686
+ rel = torch.stack(rel_list, dim=1)
687
+
688
+ rel_weight = self.rel_weight
689
+
690
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
691
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
692
+
693
+ return att_mask, cibling, head, block
694
+
695
+ def encode(self, x, pos, att_mask=None, context_layers=False):
696
+ """Structformer encoding process."""
697
+
698
+ if context_layers:
699
+ """Standard transformer encode process."""
700
+ h = self.emb(x)
701
+ if hasattr(self, 'pos_emb'):
702
+ h = h + self.pos_emb(pos)
703
+ h_list = []
704
+ visibility = self.visibility(x, x.device)
705
+ for i in range(self.n_context_layers):
706
+ h_list.append(h)
707
+ h = self.context_layers[i](
708
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
709
+
710
+ output = h
711
+ h_array = torch.stack(h_list, dim=2)
712
+ return output
713
+
714
+ else:
715
+ visibility = self.visibility(x, x.device)
716
+ h = self.emb(x)
717
+ if hasattr(self, 'pos_emb'):
718
+ assert pos.max() < 500
719
+ h = h + self.pos_emb(pos)
720
+ for i in range(self.nlayers):
721
+ h = self.layers[i](
722
+ h.transpose(0, 1), attn_mask=att_mask[i],
723
+ key_padding_mask=visibility).transpose(0, 1)
724
+ return h
725
+
726
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
727
+
728
+ x = input_ids
729
+ batch_size, length = x.size()
730
+
731
+ if position_ids is None:
732
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
733
+
734
+ context_layers_output = None
735
+ if self.n_context_layers > 0:
736
+ context_layers_output = self.encode(x, pos, context_layers=True)
737
+
738
+ distance, height = self.parse(x, pos, embeds=context_layers_output)
739
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
740
+
741
+ raw_output = self.encode(x, pos, att_mask)
742
+ raw_output = self.norm(raw_output)
743
+ raw_output = self.drop(raw_output)
744
+
745
+ output = self.output_layer(raw_output)
746
+
747
+ loss = None
748
+ if labels is not None:
749
+ loss_fct = nn.CrossEntropyLoss()
750
+ loss = loss_fct(output.view(batch_size * length, -1), labels.reshape(-1))
751
+
752
+ return MaskedLMOutput(
753
+ loss=loss, # shape: 1
754
+ logits=output, # shape: (batch_size * length, ntokens)
755
+ hidden_states=None,
756
+ attentions=None,
757
+ )
758
+
759
+
760
+
761
+
762
+ class StructFormerClassification(Transformer):
763
+ """StructFormer model."""
764
+
765
+ def __init__(self,
766
+ hidden_size,
767
+ n_context_layers,
768
+ nlayers,
769
+ ntokens,
770
+ nhead=8,
771
+ dropout=0.1,
772
+ dropatt=0.1,
773
+ relative_bias=False,
774
+ pos_emb=False,
775
+ pad=0,
776
+ n_parser_layers=4,
777
+ conv_size=9,
778
+ relations=('head', 'child'),
779
+ weight_act='softmax',
780
+ config=None,
781
+ ):
782
+
783
+
784
+ super(StructFormerClassification, self).__init__(
785
+ hidden_size,
786
+ nlayers,
787
+ ntokens,
788
+ nhead=nhead,
789
+ dropout=dropout,
790
+ dropatt=dropatt,
791
+ relative_bias=relative_bias,
792
+ pos_emb=pos_emb,
793
+ pad=pad)
794
+
795
+ self.num_labels = config.num_labels
796
+ self.config = config
797
+
798
+ if n_context_layers > 0:
799
+ self.context_layers = nn.ModuleList([
800
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
801
+ dropatt=dropatt, relative_bias=relative_bias)
802
+ for _ in range(n_context_layers)])
803
+
804
+ self.parser_layers = nn.ModuleList([
805
+ nn.Sequential(Conv1d(hidden_size, conv_size),
806
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
807
+ nn.Tanh()) for i in range(n_parser_layers)])
808
+
809
+ self.distance_ff = nn.Sequential(
810
+ Conv1d(hidden_size, 2),
811
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
812
+ nn.Linear(hidden_size, 1))
813
+
814
+ self.height_ff = nn.Sequential(
815
+ nn.Linear(hidden_size, hidden_size),
816
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
817
+ nn.Linear(hidden_size, 1))
818
+
819
+ n_rel = len(relations)
820
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
821
+ self._rel_weight.data.normal_(0, 0.1)
822
+
823
+ self._scaler = nn.Parameter(torch.zeros(2))
824
+
825
+ self.n_parse_layers = n_parser_layers
826
+ self.n_context_layers = n_context_layers
827
+ self.weight_act = weight_act
828
+ self.relations = relations
829
+
830
+ self.classifier = RobertaClassificationHead(config)
831
+
832
+ @property
833
+ def scaler(self):
834
+ return self._scaler.exp()
835
+
836
+ @property
837
+ def rel_weight(self):
838
+ if self.weight_act == 'sigmoid':
839
+ return torch.sigmoid(self._rel_weight)
840
+ elif self.weight_act == 'softmax':
841
+ return torch.softmax(self._rel_weight, dim=-1)
842
+
843
+ def parse(self, x, pos, embeds=None):
844
+ """Parse input sentence.
845
+ Args:
846
+ x: input tokens (required).
847
+ pos: position for each token (optional).
848
+ Returns:
849
+ distance: syntactic distance
850
+ height: syntactic height
851
+ """
852
+
853
+ mask = (x != self.pad)
854
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
855
+
856
+
857
+ if embeds is not None:
858
+ h = embeds
859
+ else:
860
+ h = self.emb(x)
861
+
862
+ for i in range(self.n_parse_layers):
863
+ h = h.masked_fill(~mask[:, :, None], 0)
864
+ h = self.parser_layers[i](h)
865
+
866
+ height = self.height_ff(h).squeeze(-1)
867
+ height.masked_fill_(~mask, -1e9)
868
+
869
+ distance = self.distance_ff(h).squeeze(-1)
870
+ distance.masked_fill_(~mask_shifted, 1e9)
871
+
872
+ # Calibrating the distance and height to the same level
873
+ length = distance.size(1)
874
+ height_max = height[:, None, :].expand(-1, length, -1)
875
+ height_max = torch.cummax(
876
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
877
+ dim=-1)[0].triu(0)
878
+
879
+ margin_left = torch.relu(
880
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
881
+ margin_right = torch.relu(distance[:, None, :] - height_max)
882
+ margin = torch.where(margin_left > margin_right, margin_right,
883
+ margin_left).triu(0)
884
+
885
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
886
+ margin.masked_fill_(~margin_mask, 0)
887
+ margin = margin.max()
888
+
889
+ distance = distance - margin
890
+
891
+ return distance, height
892
+
893
+ def compute_block(self, distance, height):
894
+ """Compute constituents from distance and height."""
895
+
896
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
897
+
898
+ gamma = torch.sigmoid(-beta_logits)
899
+ ones = torch.ones_like(gamma)
900
+
901
+ block_mask_left = cummin(
902
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
903
+ block_mask_left = block_mask_left - F.pad(
904
+ block_mask_left[:, :, :-1], (1, 0), value=0)
905
+ block_mask_left.tril_(0)
906
+
907
+ block_mask_right = cummin(
908
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
909
+ block_mask_right = block_mask_right - F.pad(
910
+ block_mask_right[:, :, 1:], (0, 1), value=0)
911
+ block_mask_right.triu_(0)
912
+
913
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
914
+ block = cumsum(block_mask_left).tril(0) + cumsum(
915
+ block_mask_right, reverse=True).triu(1)
916
+
917
+ return block_p, block
918
+
919
+ def compute_head(self, height):
920
+ """Estimate head for each constituent."""
921
+
922
+ _, length = height.size()
923
+ head_logits = height * self.scaler[1]
924
+ index = torch.arange(length, device=height.device)
925
+
926
+ mask = (index[:, None, None] <= index[None, None, :]) * (
927
+ index[None, None, :] <= index[None, :, None])
928
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
929
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
930
+
931
+ head_p = torch.softmax(head_logits, dim=-1)
932
+
933
+ return head_p
934
+
935
+ def generate_mask(self, x, distance, height):
936
+ """Compute head and cibling distribution for each token."""
937
+
938
+ bsz, length = x.size()
939
+
940
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
941
+ eye = eye[None, :, :].expand((bsz, -1, -1))
942
+
943
+ block_p, block = self.compute_block(distance, height)
944
+ head_p = self.compute_head(height)
945
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
946
+ head = head.masked_fill(eye, 0)
947
+ child = head.transpose(1, 2)
948
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
949
+
950
+ rel_list = []
951
+ if 'head' in self.relations:
952
+ rel_list.append(head)
953
+ if 'child' in self.relations:
954
+ rel_list.append(child)
955
+ if 'cibling' in self.relations:
956
+ rel_list.append(cibling)
957
+
958
+ rel = torch.stack(rel_list, dim=1)
959
+
960
+ rel_weight = self.rel_weight
961
+
962
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
963
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
964
+
965
+ return att_mask, cibling, head, block
966
+
967
+ def encode(self, x, pos, att_mask=None, context_layers=False):
968
+ """Structformer encoding process."""
969
+
970
+ if context_layers:
971
+ """Standard transformer encode process."""
972
+ h = self.emb(x)
973
+ if hasattr(self, 'pos_emb'):
974
+ h = h + self.pos_emb(pos)
975
+ h_list = []
976
+ visibility = self.visibility(x, x.device)
977
+ for i in range(self.n_context_layers):
978
+ h_list.append(h)
979
+ h = self.context_layers[i](
980
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
981
+
982
+ output = h
983
+ h_array = torch.stack(h_list, dim=2)
984
+ return output
985
+
986
+ else:
987
+ visibility = self.visibility(x, x.device)
988
+ h = self.emb(x)
989
+ if hasattr(self, 'pos_emb'):
990
+ assert pos.max() < 500
991
+ h = h + self.pos_emb(pos)
992
+ for i in range(self.nlayers):
993
+ h = self.layers[i](
994
+ h.transpose(0, 1), attn_mask=att_mask[i],
995
+ key_padding_mask=visibility).transpose(0, 1)
996
+ return h
997
+
998
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
999
+
1000
+ x = input_ids
1001
+ batch_size, length = x.size()
1002
+
1003
+ if position_ids is None:
1004
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
1005
+
1006
+ context_layers_output = None
1007
+ if self.n_context_layers > 0:
1008
+ context_layers_output = self.encode(x, pos, context_layers=True)
1009
+
1010
+ distance, height = self.parse(x, pos, embeds=context_layers_output)
1011
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
1012
+
1013
+ raw_output = self.encode(x, pos, att_mask)
1014
+ raw_output = self.norm(raw_output)
1015
+ raw_output = self.drop(raw_output)
1016
+
1017
+ #output = self.output_layer(raw_output)
1018
+ logits = self.classifier(raw_output)
1019
+
1020
+ loss = None
1021
+ if labels is not None:
1022
+ if self.config.problem_type is None:
1023
+ if self.num_labels == 1:
1024
+ self.config.problem_type = "regression"
1025
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1026
+ self.config.problem_type = "single_label_classification"
1027
+ else:
1028
+ self.config.problem_type = "multi_label_classification"
1029
+
1030
+ if self.config.problem_type == "regression":
1031
+ loss_fct = MSELoss()
1032
+ if self.num_labels == 1:
1033
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1034
+ else:
1035
+ loss = loss_fct(logits, labels)
1036
+ elif self.config.problem_type == "single_label_classification":
1037
+ loss_fct = CrossEntropyLoss()
1038
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1039
+ elif self.config.problem_type == "multi_label_classification":
1040
+ loss_fct = BCEWithLogitsLoss()
1041
+ loss = loss_fct(logits, labels)
1042
+
1043
+
1044
+ return SequenceClassifierOutput(
1045
+ loss=loss,
1046
+ logits=logits,
1047
+ hidden_states=None,
1048
+ attentions=None,
1049
+ )
1050
+
1051
+
1052
+
1053
+ ##########################################
1054
+ # HuggingFace Model
1055
+ ##########################################
1056
+ class StructformerModel(PreTrainedModel):
1057
+ config_class = StructformerConfig
1058
+
1059
+ def __init__(self, config):
1060
+ super().__init__(config)
1061
+ self.model = StructFormer(
1062
+ hidden_size=config.hidden_size,
1063
+ n_context_layers=config.n_context_layers,
1064
+ nlayers=config.nlayers,
1065
+ ntokens=config.ntokens,
1066
+ nhead=config.nhead,
1067
+ dropout=config.dropout,
1068
+ dropatt=config.dropatt,
1069
+ relative_bias=config.relative_bias,
1070
+ pos_emb=config.pos_emb,
1071
+ pad=config.pad,
1072
+ n_parser_layers=config.n_parser_layers,
1073
+ conv_size=config.conv_size,
1074
+ relations=config.relations,
1075
+ weight_act=config.weight_act
1076
+ )
1077
+
1078
+ def forward(self, input_ids, labels=None, **kwargs):
1079
+ return self.model(input_ids, labels=labels, **kwargs)
1080
+
1081
+
1082
+
1083
+ class StructformerModelForSequenceClassification(PreTrainedModel):
1084
+ config_class = StructformerConfig
1085
+ def __init__(self, config):
1086
+ super().__init__(config)
1087
+ self.model = StructFormerClassification(
1088
+ hidden_size=config.hidden_size,
1089
+ n_context_layers=config.n_context_layers,
1090
+ nlayers=config.nlayers,
1091
+ ntokens=config.ntokens,
1092
+ nhead=config.nhead,
1093
+ dropout=config.dropout,
1094
+ dropatt=config.dropatt,
1095
+ relative_bias=config.relative_bias,
1096
+ pos_emb=config.pos_emb,
1097
+ pad=config.pad,
1098
+ n_parser_layers=config.n_parser_layers,
1099
+ conv_size=config.conv_size,
1100
+ relations=config.relations,
1101
+ weight_act=config.weight_act,
1102
+ config=config)
1103
+
1104
+ def _init_weights(self, module):
1105
+ """Initialize the weights"""
1106
+ if isinstance(module, nn.Linear):
1107
+ # Slightly different from the TF version which uses truncated_normal for initialization
1108
+ # cf https://github.com/pytorch/pytorch/pull/5617
1109
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1110
+ if module.bias is not None:
1111
+ module.bias.data.zero_()
1112
+ elif isinstance(module, nn.Embedding):
1113
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1114
+ if module.padding_idx is not None:
1115
+ module.weight.data[module.padding_idx].zero_()
1116
+ elif isinstance(module, nn.LayerNorm):
1117
+ if module.bias is not None:
1118
+ module.bias.data.zero_()
1119
+ module.weight.data.fill_(1.0)
1120
+
1121
+
1122
+ def forward(self, input_ids, labels=None, **kwargs):
1123
+ return self.model(input_ids, labels=labels, **kwargs)
finetune/boolq/tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
+ {
+ "add_prefix_space": false,
+ "bos_token": {
+ "__type": "AddedToken",
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "cls_token": {
+ "__type": "AddedToken",
+ "content": "<s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "eos_token": {
+ "__type": "AddedToken",
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "errors": "replace",
+ "mask_token": {
+ "__type": "AddedToken",
+ "content": "<mask>",
+ "lstrip": true,
+ "normalized": false,
+ "rstrip": false,
+ "single_word": false
+ },
+ "model_max_length": 512,
+ "name_or_path": "omarmomen/structformer_s1_final_with_pos",
+ "pad_token": {
+ "__type": "AddedToken",
+ "content": "<pad>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "sep_token": {
+ "__type": "AddedToken",
+ "content": "</s>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ },
+ "special_tokens_map_file": null,
+ "tokenizer_class": "RobertaTokenizer",
+ "trim_offsets": true,
+ "unk_token": {
+ "__type": "AddedToken",
+ "content": "<unk>",
+ "lstrip": false,
+ "normalized": true,
+ "rstrip": false,
+ "single_word": false
+ }
+ }
finetune/boolq/train_results.json ADDED
@@ -0,0 +1,8 @@
+ {
+   "epoch": 10.0,
+   "train_loss": 0.5978388892279731,
+   "train_runtime": 99.3528,
+   "train_samples": 2072,
+   "train_samples_per_second": 208.55,
+   "train_steps_per_second": 1.812
+ }
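The throughput fields follow directly from the other numbers: 2072 training samples over 10 epochs in 99.35 s gives the reported 208.55 samples per second, and the 180 optimizer steps recorded in the trainer_state.json below give 1.812 steps per second. A quick check:

# Quick arithmetic check of the reported throughput numbers.
train_samples, epochs, runtime, steps = 2072, 10, 99.3528, 180
print(round(train_samples * epochs / runtime, 2))  # 208.55 -> train_samples_per_second
print(round(steps / runtime, 3))                   # 1.812  -> train_steps_per_second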
finetune/boolq/trainer_state.json ADDED
@@ -0,0 +1,25 @@
+ {
+   "best_metric": null,
+   "best_model_checkpoint": null,
+   "epoch": 10.0,
+   "global_step": 180,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 10.0,
+       "step": 180,
+       "total_flos": 1729574880337920.0,
+       "train_loss": 0.5978388892279731,
+       "train_runtime": 99.3528,
+       "train_samples_per_second": 208.55,
+       "train_steps_per_second": 1.812
+     }
+   ],
+   "max_steps": 180,
+   "num_train_epochs": 10,
+   "total_flos": 1729574880337920.0,
+   "trial_name": null,
+   "trial_params": null
+ }
finetune/boolq/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:884024d2e1f216c9a9e919729a7ade7c914cf9a7f3aaef5d07b647af45d074a5
+ size 3503
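training_args.bin (like the other .bin/.pt files in this commit) is stored through Git LFS; the three lines above are the LFS pointer, not the payload. Once the LFS objects are fetched, the file is simply a pickled TrainingArguments object. A minimal sketch of inspecting it, with the path as a placeholder:

# Sketch: inspect the serialized TrainingArguments (requires the real LFS object,
# e.g. after `git lfs pull`; on newer PyTorch you may need weights_only=False).
import torch
import transformers  # the TrainingArguments class must be importable for unpickling

args = torch.load("finetune/boolq/training_args.bin")  # placeholder path
print(args.num_train_epochs, args.learning_rate, args.per_device_train_batch_size)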
finetune/boolq/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
finetune/cola/all_results.json ADDED
@@ -0,0 +1,16 @@
+ {
+   "epoch": 10.0,
+   "eval_accuracy": 0.689892053604126,
+   "eval_f1": 0.7931937172774869,
+   "eval_loss": 0.762876570224762,
+   "eval_mcc": 0.19495355846277174,
+   "eval_runtime": 1.9662,
+   "eval_samples": 1019,
+   "eval_samples_per_second": 518.247,
+   "eval_steps_per_second": 65.099,
+   "train_loss": 0.3857684093972911,
+   "train_runtime": 391.5771,
+   "train_samples": 8164,
+   "train_samples_per_second": 208.49,
+   "train_steps_per_second": 1.762
+ }
finetune/cola/checkpoint-400/config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "_name_or_path": "omarmomen/structformer_s1_final_with_pos",
+   "architectures": [
+     "StructformerModelForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoConfig": "structformer_as_hf.StructformerConfig",
+     "AutoModelForMaskedLM": "structformer_as_hf.StructformerModel",
+     "AutoModelForSequenceClassification": "structformer_as_hf.StructformerModelForSequenceClassification"
+   },
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "conv_size": 9,
+   "dropatt": 0.1,
+   "dropout": 0.1,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": 0,
+     "1": 1
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "0": 0,
+     "1": 1
+   },
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "structformer",
+   "n_context_layers": 0,
+   "n_parser_layers": 4,
+   "nhead": 12,
+   "nlayers": 12,
+   "ntokens": 32000,
+   "num_attention_heads": 8,
+   "num_hidden_layers": 8,
+   "pad": 0,
+   "pad_token_id": 1,
+   "pos_emb": true,
+   "position_embedding_type": "absolute",
+   "problem_type": "single_label_classification",
+   "relations": [
+     "head",
+     "child"
+   ],
+   "relative_bias": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.26.1",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 32000,
+   "weight_act": "softmax"
+ }
finetune/cola/checkpoint-400/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
finetune/cola/checkpoint-400/optimizer.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:09a754a103adc04bc61c97c0cd7151bd44370f268ddea835479252b0ea4cae13
+ size 1069068057
finetune/cola/checkpoint-400/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:8bc16b9a152d3c83b6da1a5a69be258541741700af78a1cb388eaf13d5424ed4
+ size 534669003
finetune/cola/checkpoint-400/rng_state.pth ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:5769fda57687e02db326bee3281d32f541cdc788f737f37c47c3a91239b699cc
+ size 14503
finetune/cola/checkpoint-400/scheduler.pt ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:2b09db5c9b9264697e44562c70af05a12a67180f480992d419fa287214c7be7a
+ size 623
finetune/cola/checkpoint-400/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
+ {
+   "bos_token": "<s>",
+   "cls_token": "<s>",
+   "eos_token": "</s>",
+   "mask_token": {
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "pad_token": "<pad>",
+   "sep_token": "</s>",
+   "unk_token": "<unk>"
+ }
finetune/cola/checkpoint-400/structformer_as_hf.py ADDED
@@ -0,0 +1,1123 @@
finetune/cola/checkpoint-400/tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
+ {
+   "add_prefix_space": false,
+   "bos_token": {
+     "__type": "AddedToken",
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "cls_token": {
+     "__type": "AddedToken",
+     "content": "<s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "eos_token": {
+     "__type": "AddedToken",
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "errors": "replace",
+   "mask_token": {
+     "__type": "AddedToken",
+     "content": "<mask>",
+     "lstrip": true,
+     "normalized": false,
+     "rstrip": false,
+     "single_word": false
+   },
+   "model_max_length": 512,
+   "name_or_path": "omarmomen/structformer_s1_final_with_pos",
+   "pad_token": {
+     "__type": "AddedToken",
+     "content": "<pad>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "sep_token": {
+     "__type": "AddedToken",
+     "content": "</s>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   },
+   "special_tokens_map_file": null,
+   "tokenizer_class": "RobertaTokenizer",
+   "trim_offsets": true,
+   "unk_token": {
+     "__type": "AddedToken",
+     "content": "<unk>",
+     "lstrip": false,
+     "normalized": true,
+     "rstrip": false,
+     "single_word": false
+   }
+ }
finetune/cola/checkpoint-400/trainer_state.json ADDED
@@ -0,0 +1,27 @@
+ {
+   "best_metric": 0.7931937172774869,
+   "best_model_checkpoint": "finetune_results/omarmomen/structformer_s1_final_with_pos/cola/checkpoint-400",
+   "epoch": 5.797101449275362,
+   "global_step": 400,
+   "is_hyper_param_search": false,
+   "is_local_process_zero": true,
+   "is_world_process_zero": true,
+   "log_history": [
+     {
+       "epoch": 5.8,
+       "eval_accuracy": 0.689892053604126,
+       "eval_f1": 0.7931937172774869,
+       "eval_loss": 0.762876570224762,
+       "eval_mcc": 0.19495355846277174,
+       "eval_runtime": 1.9776,
+       "eval_samples_per_second": 515.265,
+       "eval_steps_per_second": 64.724,
+       "step": 400
+     }
+   ],
+   "max_steps": 690,
+   "num_train_epochs": 10,
+   "total_flos": 3958322433669120.0,
+   "trial_name": null,
+   "trial_params": null
+ }
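This trainer_state.json marks checkpoint-400 as the best model so far, selected on eval F1 at step 400 of 690. A minimal sketch of reading that information back out; the path is a placeholder for wherever the repository is cloned:

# Sketch: locate the checkpoint that the Trainer recorded as best.
import json

with open("finetune/cola/checkpoint-400/trainer_state.json") as f:  # placeholder path
    state = json.load(f)

print(state["best_metric"])            # 0.7931... (eval F1 at step 400)
print(state["best_model_checkpoint"])  # what load_best_model_at_end would reload

Resuming the run from this point would typically go through trainer.train(resume_from_checkpoint=...) with this directory, which is why the optimizer, scheduler and RNG states are saved alongside the weights.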
finetune/cola/checkpoint-400/training_args.bin ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:9f4559102ecfa4341cceea0c32d0db1fe5185ac1bd8ad00af1c57c8dfaa12f80
+ size 3503
finetune/cola/checkpoint-400/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
finetune/cola/config.json ADDED
@@ -0,0 +1,57 @@
+ {
+   "_name_or_path": "omarmomen/structformer_s1_final_with_pos",
+   "architectures": [
+     "StructformerModelForSequenceClassification"
+   ],
+   "attention_probs_dropout_prob": 0.1,
+   "auto_map": {
+     "AutoConfig": "structformer_as_hf.StructformerConfig",
+     "AutoModelForMaskedLM": "structformer_as_hf.StructformerModel",
+     "AutoModelForSequenceClassification": "structformer_as_hf.StructformerModelForSequenceClassification"
+   },
+   "bos_token_id": 0,
+   "classifier_dropout": null,
+   "conv_size": 9,
+   "dropatt": 0.1,
+   "dropout": 0.1,
+   "eos_token_id": 2,
+   "hidden_act": "gelu",
+   "hidden_dropout_prob": 0.1,
+   "hidden_size": 768,
+   "id2label": {
+     "0": 0,
+     "1": 1
+   },
+   "initializer_range": 0.02,
+   "intermediate_size": 3072,
+   "label2id": {
+     "0": 0,
+     "1": 1
+   },
+   "layer_norm_eps": 1e-05,
+   "max_position_embeddings": 514,
+   "model_type": "structformer",
+   "n_context_layers": 0,
+   "n_parser_layers": 4,
+   "nhead": 12,
+   "nlayers": 12,
+   "ntokens": 32000,
+   "num_attention_heads": 8,
+   "num_hidden_layers": 8,
+   "pad": 0,
+   "pad_token_id": 1,
+   "pos_emb": true,
+   "position_embedding_type": "absolute",
+   "problem_type": "single_label_classification",
+   "relations": [
+     "head",
+     "child"
+   ],
+   "relative_bias": false,
+   "torch_dtype": "float32",
+   "transformers_version": "4.26.1",
+   "type_vocab_size": 1,
+   "use_cache": true,
+   "vocab_size": 32000,
+   "weight_act": "softmax"
+ }
finetune/cola/eval_results.json ADDED
@@ -0,0 +1,11 @@
+ {
+   "epoch": 10.0,
+   "eval_accuracy": 0.689892053604126,
+   "eval_f1": 0.7931937172774869,
+   "eval_loss": 0.762876570224762,
+   "eval_mcc": 0.19495355846277174,
+   "eval_runtime": 1.9662,
+   "eval_samples": 1019,
+   "eval_samples_per_second": 518.247,
+   "eval_steps_per_second": 65.099
+ }
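eval_accuracy, eval_f1 and eval_mcc are standard binary-classification metrics over the 1019 evaluation examples. The evaluation script itself is not part of this commit, but a minimal sketch of how such numbers are typically derived from predicted and gold labels (scikit-learn assumed, toy arrays as placeholders):

# Sketch only: computing the metrics reported in eval_results.json from
# predicted and gold label arrays (the actual eval script is not shown here).
import numpy as np
from sklearn.metrics import accuracy_score, f1_score, matthews_corrcoef

preds = np.array([1, 0, 1, 1, 0, 1])   # placeholder predictions (0/1), cf. predict_results.txt
golds = np.array([1, 0, 0, 1, 1, 1])   # placeholder gold labels

print("accuracy", accuracy_score(golds, preds))
print("f1      ", f1_score(golds, preds))
print("mcc     ", matthews_corrcoef(golds, preds))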
finetune/cola/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
finetune/cola/predict_results.txt ADDED
@@ -0,0 +1,1020 @@
1
+ index prediction
2
+ 0 1
3
+ 1 1
4
+ 2 1
5
+ 3 1
6
+ 4 0
7
+ 5 1
8
+ 6 1
9
+ 7 1
10
+ 8 1
11
+ 9 1
12
+ 10 1
13
+ 11 1
14
+ 12 1
15
+ 13 0
16
+ 14 1
17
+ 15 1
18
+ 16 1
19
+ 17 1
20
+ 18 1
21
+ 19 1
22
+ 20 0
23
+ 21 1
24
+ 22 1
25
+ 23 0
26
+ 24 1
27
+ 25 1
28
+ 26 1
29
+ 27 1
30
+ 28 0
31
+ 29 1
32
+ 30 1
33
+ 31 1
34
+ 32 1
35
+ 33 1
36
+ 34 0
37
+ 35 1
38
+ 36 1
39
+ 37 0
40
+ 38 1
41
+ 39 1
42
+ 40 0
43
+ 41 0
44
+ 42 1
45
+ 43 1
46
+ 44 0
47
+ 45 0
48
+ 46 1
49
+ 47 0
50
+ 48 1
51
+ 49 0
52
+ 50 1
53
+ 51 1
54
+ 52 1
55
+ 53 0
56
+ 54 0
57
+ 55 1
58
+ 56 1
59
+ 57 1
60
+ 58 1
61
+ 59 0
62
+ 60 0
63
+ 61 1
64
+ 62 0
65
+ 63 1
66
+ 64 1
67
+ 65 0
68
+ 66 1
69
+ 67 1
70
+ 68 1
71
+ 69 1
72
+ 70 0
73
+ 71 1
74
+ 72 1
75
+ 73 0
76
+ 74 0
77
+ 75 0
78
+ 76 1
79
+ 77 1
80
+ 78 1
81
+ 79 1
82
+ 80 0
83
+ 81 1
84
+ 82 1
85
+ 83 1
86
+ 84 1
87
+ 85 0
88
+ 86 0
89
+ 87 1
90
+ 88 1
91
+ 89 1
92
+ 90 1
93
+ 91 0
94
+ 92 1
95
+ 93 1
96
+ 94 1
97
+ 95 1
98
+ 96 1
99
+ 97 1
100
+ 98 1
101
+ 99 1
102
+ 100 0
103
+ 101 1
104
+ 102 1
105
+ 103 1
106
+ 104 1
107
+ 105 1
108
+ 106 1
109
+ 107 1
110
+ 108 1
111
+ 109 1
112
+ 110 1
113
+ 111 1
114
+ 112 1
115
+ 113 1
116
+ 114 1
117
+ 115 0
118
+ 116 1
119
+ 117 1
120
+ 118 0
121
+ 119 1
122
+ 120 1
123
+ 121 1
124
+ 122 1
125
+ 123 1
126
+ 124 1
127
+ 125 0
128
+ 126 1
129
+ 127 0
130
+ 128 0
131
+ 129 1
132
+ 130 1
133
+ 131 1
134
+ 132 1
135
+ 133 1
136
+ 134 1
137
+ 135 1
138
+ 136 1
139
+ 137 1
140
+ 138 1
141
+ 139 1
142
+ 140 1
143
+ 141 0
144
+ 142 1
145
+ 143 1
146
+ 144 0
147
+ 145 1
148
+ 146 1
149
+ 147 1
150
+ 148 1
151
+ 149 1
152
+ 150 0
153
+ 151 0
154
+ 152 1
155
+ 153 0
156
+ 154 1
157
+ 155 1
158
+ 156 1
159
+ 157 0
160
+ 158 1
161
+ 159 0
162
+ 160 1
163
+ 161 1
164
+ 162 1
165
+ 163 1
166
+ 164 1
167
+ 165 1
168
+ 166 1
169
+ 167 1
170
+ 168 0
171
+ 169 1
172
+ 170 1
173
+ 171 0
174
+ 172 1
175
+ 173 1
176
+ 174 1
177
+ 175 0
178
+ 176 1
179
+ 177 1
180
+ 178 0
181
+ 179 1
182
+ 180 0
183
+ 181 1
184
+ 182 1
185
+ 183 1
186
+ 184 1
187
+ 185 1
188
+ 186 1
189
+ 187 0
190
+ 188 1
191
+ 189 1
192
+ 190 0
193
+ 191 1
194
+ 192 1
195
+ 193 1
196
+ 194 1
197
+ 195 0
198
+ 196 1
199
+ 197 1
200
+ 198 0
201
+ 199 1
202
+ 200 1
203
+ 201 1
204
+ 202 1
205
+ 203 1
206
+ 204 1
207
+ 205 1
208
+ 206 1
209
+ 207 1
210
+ 208 1
211
+ 209 1
212
+ 210 1
213
+ 211 1
214
+ 212 1
215
+ 213 1
216
+ 214 1
217
+ 215 1
218
+ 216 0
219
+ 217 1
220
+ 218 1
221
+ 219 1
222
+ 220 1
223
+ 221 1
224
+ 222 1
225
+ 223 1
226
+ 224 1
227
+ 225 1
228
+ 226 1
229
+ 227 1
230
+ 228 1
231
+ 229 1
232
+ 230 1
233
+ 231 1
234
+ 232 0
235
+ 233 1
236
+ 234 1
237
+ 235 1
238
+ 236 1
239
+ 237 1
240
+ 238 1
241
+ 239 1
242
+ 240 1
243
+ 241 1
244
+ 242 0
245
+ 243 0
246
+ 244 0
247
+ 245 1
248
+ 246 1
249
+ 247 1
250
+ 248 0
251
+ 249 1
252
+ 250 1
253
+ 251 1
254
+ 252 1
255
+ 253 1
256
+ 254 0
257
+ 255 1
258
+ 256 1
259
+ 257 0
260
+ 258 1
261
+ 259 1
262
+ 260 1
263
+ 261 1
264
+ 262 1
265
+ 263 0
266
+ 264 0
267
+ 265 1
268
+ 266 1
269
+ 267 0
270
+ 268 1
271
+ 269 0
272
+ 270 1
273
+ 271 1
274
+ 272 1
275
+ 273 0
276
+ 274 1
277
+ 275 1
278
+ 276 0
279
+ 277 1
280
+ 278 1
281
+ 279 0
282
+ 280 1
283
+ 281 1
284
+ 282 1
285
+ 283 1
286
+ 284 1
287
+ 285 1
288
+ 286 0
289
+ 287 1
290
+ 288 1
291
+ 289 0
292
+ 290 1
293
+ 291 1
294
+ 292 0
295
+ 293 0
296
+ 294 1
297
+ 295 1
298
+ 296 1
299
+ 297 1
300
+ 298 1
301
+ 299 1
302
+ 300 1
303
+ 301 1
304
+ 302 1
305
+ 303 1
306
+ 304 1
307
+ 305 1
308
+ 306 1
309
+ 307 1
310
+ 308 1
311
+ 309 1
312
+ 310 1
313
+ 311 1
314
+ 312 1
315
+ 313 0
316
+ 314 1
317
+ 315 1
318
+ 316 1
319
+ 317 1
320
+ 318 1
321
+ 319 1
322
+ 320 1
323
+ 321 1
324
+ 322 0
325
+ 323 0
326
+ 324 1
327
+ 325 1
328
+ 326 1
329
+ 327 1
330
+ 328 1
331
+ 329 1
332
+ 330 1
333
+ 331 1
334
+ 332 1
335
+ 333 1
336
+ 334 1
337
+ 335 1
338
+ 336 1
339
+ 337 1
340
+ 338 0
341
+ 339 1
342
+ 340 1
343
+ 341 1
344
+ 342 1
345
+ 343 1
346
+ 344 1
347
+ 345 0
348
+ 346 1
349
+ 347 0
350
+ 348 0
351
+ 349 1
352
+ 350 1
353
+ 351 1
354
+ 352 1
355
+ 353 1
356
+ 354 1
357
+ 355 1
358
+ 356 0
359
+ 357 1
360
+ 358 1
361
+ 359 1
362
+ 360 1
363
+ 361 1
364
+ 362 1
365
+ 363 1
366
+ 364 1
367
+ 365 1
368
+ 366 0
369
+ 367 1
370
+ 368 1
371
+ 369 1
372
+ 370 1
373
+ 371 1
374
+ 372 1
375
+ 373 1
376
+ 374 1
377
+ 375 1
378
+ 376 1
379
+ 377 1
380
+ 378 1
381
+ 379 1
382
+ 380 1
383
+ 381 1
384
+ 382 1
385
+ 383 1
386
+ 384 1
387
+ 385 1
388
+ 386 1
389
+ 387 1
390
+ 388 1
391
+ 389 1
392
+ 390 1
393
+ 391 1
394
+ 392 1
395
+ 393 1
396
+ 394 1
397
+ 395 1
398
+ 396 1
399
+ 397 1
400
+ 398 1
401
+ 399 1
402
+ 400 1
403
+ 401 1
404
+ 402 0
405
+ 403 1
406
+ 404 1
407
+ 405 1
408
+ 406 1
409
+ 407 1
410
+ 408 1
411
+ 409 0
412
+ 410 1
413
+ 411 1
414
+ 412 1
415
+ 413 1
416
+ 414 1
417
+ 415 1
418
+ 416 1
419
+ 417 1
420
+ 418 1
421
+ 419 1
422
+ 420 1
423
+ 421 1
424
+ 422 1
425
+ 423 1
426
+ 424 1
427
+ 425 1
428
+ 426 1
429
+ 427 1
430
+ 428 0
431
+ 429 1
432
+ 430 0
433
+ 431 1
434
+ 432 1
435
+ 433 1
436
+ 434 1
437
+ 435 1
438
+ 436 1
439
+ 437 1
440
+ 438 1
441
+ 439 1
442
+ 440 1
443
+ 441 1
444
+ 442 1
445
+ 443 0
446
+ 444 1
447
+ 445 0
448
+ 446 0
449
+ 447 1
450
+ 448 1
451
+ 449 1
452
+ 450 1
453
+ 451 0
454
+ 452 0
455
+ 453 1
456
+ 454 1
457
+ 455 0
458
+ 456 1
459
+ 457 1
460
+ 458 1
461
+ 459 1
462
+ 460 1
463
+ 461 1
464
+ 462 0
465
+ 463 1
466
+ 464 0
467
+ 465 1
468
+ 466 0
469
+ 467 0
470
+ 468 1
471
+ 469 1
472
+ 470 1
473
+ 471 1
474
+ 472 1
475
+ 473 0
476
+ 474 1
477
+ 475 1
478
+ 476 0
479
+ 477 1
480
+ 478 0
481
+ 479 1
482
+ 480 1
483
+ 481 1
484
+ 482 1
485
+ 483 0
486
+ 484 0
487
+ 485 1
488
+ 486 0
489
+ 487 1
490
+ 488 0
491
+ 489 1
492
+ 490 1
493
+ 491 1
494
+ 492 1
495
+ 493 0
496
+ 494 1
497
+ 495 1
498
+ 496 1
499
+ 497 0
500
+ 498 1
501
+ 499 1
502
+ 500 1
503
+ 501 1
504
+ 502 1
505
+ 503 0
506
+ 504 0
507
+ 505 0
508
+ 506 0
509
+ 507 1
510
+ 508 0
511
+ 509 0
512
+ 510 1
513
+ 511 1
514
+ 512 1
515
+ 513 0
516
+ 514 1
517
+ 515 1
518
+ 516 0
519
+ 517 0
520
+ 518 1
521
+ 519 1
522
+ 520 0
523
+ 521 1
524
+ 522 1
525
+ 523 1
526
+ 524 1
527
+ 525 1
528
+ 526 0
529
+ 527 1
530
+ 528 1
531
+ 529 1
532
+ 530 1
533
+ 531 1
534
+ 532 1
535
+ 533 1
536
+ 534 1
537
+ 535 1
538
+ 536 1
539
+ 537 1
540
+ 538 1
541
+ 539 1
542
+ 540 1
543
+ 541 1
544
+ 542 1
545
+ 543 1
546
+ 544 1
547
+ 545 1
548
+ 546 1
549
+ 547 1
550
+ 548 1
551
+ 549 1
552
+ 550 1
553
+ 551 1
554
+ 552 1
555
+ 553 1
556
+ 554 1
557
+ 555 1
558
+ 556 1
559
+ 557 1
560
+ 558 1
561
+ 559 1
562
+ 560 1
563
+ 561 1
564
+ 562 0
565
+ 563 1
566
+ 564 1
567
+ 565 1
568
+ 566 1
569
+ 567 1
570
+ 568 1
571
+ 569 0
572
+ 570 0
573
+ 571 1
574
+ 572 1
575
+ 573 1
576
+ 574 1
577
+ 575 1
578
+ 576 1
579
+ 577 1
580
+ 578 1
581
+ 579 1
582
+ 580 1
583
+ 581 1
584
+ 582 1
585
+ 583 1
586
+ 584 1
587
+ 585 1
588
+ 586 1
589
+ 587 1
590
+ 588 0
591
+ 589 1
592
+ 590 0
593
+ 591 1
594
+ 592 1
595
+ 593 1
596
+ 594 1
597
+ 595 1
598
+ 596 1
599
+ 597 0
600
+ 598 1
601
+ 599 0
602
+ 600 0
603
+ 601 1
604
+ 602 0
605
+ 603 1
606
+ 604 1
607
+ 605 0
608
+ 606 1
609
+ 607 1
610
+ 608 1
611
+ 609 1
612
+ 610 0
613
+ 611 1
614
+ 612 1
615
+ 613 1
616
+ 614 1
617
+ 615 1
618
+ 616 0
619
+ 617 1
620
+ 618 1
621
+ 619 1
622
+ 620 1
623
+ 621 0
624
+ 622 0
625
+ 623 1
626
+ 624 1
627
+ 625 1
628
+ 626 1
629
+ 627 1
630
+ 628 1
631
+ 629 1
632
+ 630 1
633
+ 631 1
634
+ 632 1
635
+ 633 1
636
+ 634 0
637
+ 635 1
638
+ 636 0
639
+ 637 1
640
+ 638 1
641
+ 639 1
642
+ 640 0
643
+ 641 1
644
+ 642 0
645
+ 643 1
646
+ 644 1
647
+ 645 1
648
+ 646 1
649
+ 647 1
650
+ 648 1
651
+ 649 1
652
+ 650 1
653
+ 651 0
654
+ 652 0
655
+ 653 1
656
+ 654 0
657
+ 655 1
658
+ 656 1
659
+ 657 1
660
+ 658 1
661
+ 659 1
662
+ 660 1
663
+ 661 1
664
+ 662 1
665
+ 663 1
666
+ 664 1
667
+ 665 0
668
+ 666 1
669
+ 667 1
670
+ 668 1
671
+ 669 1
672
+ 670 1
673
+ 671 0
674
+ 672 1
675
+ 673 0
676
+ 674 1
677
+ 675 1
678
+ 676 0
679
+ 677 1
680
+ 678 1
681
+ 679 1
682
+ 680 1
683
+ 681 0
684
+ 682 0
685
+ 683 1
686
+ 684 1
687
+ 685 0
688
+ 686 0
689
+ 687 1
690
+ 688 0
691
+ 689 1
692
+ 690 1
693
+ 691 1
694
+ 692 1
695
+ 693 1
696
+ 694 1
697
+ 695 1
698
+ 696 1
699
+ 697 0
700
+ 698 0
701
+ 699 1
702
+ 700 1
703
+ 701 1
704
+ 702 1
705
+ 703 1
706
+ 704 1
707
+ 705 1
708
+ 706 1
709
+ 707 1
710
+ 708 0
711
+ 709 1
712
+ 710 1
713
+ 711 1
714
+ 712 1
715
+ 713 1
716
+ 714 1
717
+ 715 1
718
+ 716 1
719
+ 717 1
720
+ 718 0
721
+ 719 1
722
+ 720 0
723
+ 721 1
724
+ 722 1
725
+ 723 1
726
+ 724 0
727
+ 725 0
728
+ 726 1
729
+ 727 1
730
+ 728 1
731
+ 729 1
732
+ 730 1
733
+ 731 1
734
+ 732 0
735
+ 733 0
736
+ 734 0
737
+ 735 1
738
+ 736 1
739
+ 737 0
740
+ 738 1
741
+ 739 1
742
+ 740 1
743
+ 741 1
744
+ 742 1
745
+ 743 1
746
+ 744 1
747
+ 745 1
748
+ 746 1
749
+ 747 1
750
+ 748 1
751
+ 749 1
752
+ 750 1
753
+ 751 1
754
+ 752 1
755
+ 753 1
756
+ 754 1
757
+ 755 1
758
+ 756 1
759
+ 757 1
760
+ 758 1
761
+ 759 1
762
+ 760 1
763
+ 761 1
764
+ 762 1
765
+ 763 1
766
+ 764 1
767
+ 765 0
768
+ 766 1
769
+ 767 1
770
+ 768 1
771
+ 769 1
772
+ 770 1
773
+ 771 1
774
+ 772 1
775
+ 773 1
776
+ 774 1
777
+ 775 1
778
+ 776 1
779
+ 777 1
780
+ 778 1
781
+ 779 1
782
+ 780 1
783
+ 781 1
784
+ 782 1
785
+ 783 1
786
+ 784 0
787
+ 785 1
788
+ 786 0
789
+ 787 1
790
+ 788 1
791
+ 789 1
792
+ 790 1
793
+ 791 1
794
+ 792 1
795
+ 793 1
796
+ 794 0
797
+ 795 1
798
+ 796 0
799
+ 797 1
800
+ 798 1
801
+ 799 1
802
+ 800 1
803
+ 801 1
804
+ 802 1
805
+ 803 1
806
+ 804 1
807
+ 805 1
808
+ 806 1
809
+ 807 1
810
+ 808 1
811
+ 809 1
812
+ 810 1
813
+ 811 1
814
+ 812 0
815
+ 813 0
816
+ 814 0
817
+ 815 0
818
+ 816 0
819
+ 817 1
820
+ 818 1
821
+ 819 1
822
+ 820 1
823
+ 821 1
824
+ 822 1
825
+ 823 1
826
+ 824 1
827
+ 825 1
828
+ 826 1
829
+ 827 1
830
+ 828 1
831
+ 829 1
832
+ 830 1
833
+ 831 1
834
+ 832 1
835
+ 833 0
836
+ 834 1
837
+ 835 1
838
+ 836 0
839
+ 837 1
840
+ 838 1
841
+ 839 0
842
+ 840 1
843
+ 841 1
844
+ 842 1
845
+ 843 1
846
+ 844 0
847
+ 845 0
848
+ 846 1
849
+ 847 1
850
+ 848 1
851
+ 849 1
852
+ 850 1
853
+ 851 1
854
+ 852 1
855
+ 853 1
856
+ 854 1
857
+ 855 1
858
+ 856 1
859
+ 857 1
860
+ 858 1
861
+ 859 1
862
+ 860 1
863
+ 861 1
864
+ 862 1
865
+ 863 1
866
+ 864 0
867
+ 865 1
868
+ 866 1
869
+ 867 0
870
+ 868 1
871
+ 869 1
872
+ 870 1
873
+ 871 1
874
+ 872 0
875
+ 873 1
876
+ 874 1
877
+ 875 1
878
+ 876 1
879
+ 877 1
880
+ 878 1
881
+ 879 1
882
+ 880 1
883
+ 881 1
884
+ 882 1
885
+ 883 1
886
+ 884 1
887
+ 885 1
888
+ 886 0
889
+ 887 1
890
+ 888 1
891
+ 889 1
892
+ 890 1
893
+ 891 1
894
+ 892 1
895
+ 893 1
896
+ 894 1
897
+ 895 1
898
+ 896 1
899
+ 897 0
900
+ 898 1
901
+ 899 1
902
+ 900 1
903
+ 901 1
904
+ 902 1
905
+ 903 1
906
+ 904 1
907
+ 905 1
908
+ 906 0
909
+ 907 1
910
+ 908 1
911
+ 909 1
912
+ 910 0
913
+ 911 1
914
+ 912 1
915
+ 913 1
916
+ 914 1
917
+ 915 1
918
+ 916 1
919
+ 917 1
920
+ 918 1
921
+ 919 1
922
+ 920 1
923
+ 921 1
924
+ 922 1
925
+ 923 1
926
+ 924 1
927
+ 925 1
928
+ 926 1
929
+ 927 1
930
+ 928 1
931
+ 929 1
932
+ 930 1
933
+ 931 1
934
+ 932 1
935
+ 933 1
936
+ 934 1
937
+ 935 1
938
+ 936 1
939
+ 937 1
940
+ 938 1
941
+ 939 0
942
+ 940 1
943
+ 941 1
944
+ 942 1
945
+ 943 1
946
+ 944 1
947
+ 945 1
948
+ 946 1
949
+ 947 0
950
+ 948 0
951
+ 949 1
952
+ 950 0
953
+ 951 0
954
+ 952 1
955
+ 953 1
956
+ 954 0
957
+ 955 1
958
+ 956 0
959
+ 957 1
960
+ 958 1
961
+ 959 1
962
+ 960 0
963
+ 961 1
964
+ 962 0
965
+ 963 1
966
+ 964 1
967
+ 965 1
968
+ 966 1
969
+ 967 0
970
+ 968 1
971
+ 969 1
972
+ 970 1
973
+ 971 0
974
+ 972 1
975
+ 973 1
976
+ 974 1
977
+ 975 1
978
+ 976 1
979
+ 977 1
980
+ 978 0
981
+ 979 1
982
+ 980 1
983
+ 981 1
984
+ 982 1
985
+ 983 1
986
+ 984 1
987
+ 985 0
988
+ 986 1
989
+ 987 1
990
+ 988 1
991
+ 989 0
992
+ 990 0
993
+ 991 1
994
+ 992 0
995
+ 993 1
996
+ 994 1
997
+ 995 1
998
+ 996 0
999
+ 997 1
1000
+ 998 1
1001
+ 999 1
1002
+ 1000 1
1003
+ 1001 0
1004
+ 1002 0
1005
+ 1003 0
1006
+ 1004 1
1007
+ 1005 1
1008
+ 1006 0
1009
+ 1007 0
1010
+ 1008 0
1011
+ 1009 1
1012
+ 1010 1
1013
+ 1011 0
1014
+ 1012 1
1015
+ 1013 1
1016
+ 1014 1
1017
+ 1015 1
1018
+ 1016 1
1019
+ 1017 0
1020
+ 1018 1
finetune/cola/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:8bc16b9a152d3c83b6da1a5a69be258541741700af78a1cb388eaf13d5424ed4
3
+ size 534669003
finetune/cola/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
finetune/cola/structformer_as_hf.py ADDED
@@ -0,0 +1,1123 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ from transformers import PreTrainedModel
6
+ from transformers import PretrainedConfig
7
+ from transformers.modeling_outputs import MaskedLMOutput
8
+ from typing import List
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ MaskedLMOutput,
14
+ SequenceClassifierOutput
15
+ )
16
+
17
+ ##########################################
18
+ # HuggingFace Config
19
+ ##########################################
20
+ class StructformerConfig(PretrainedConfig):
21
+ model_type = "structformer"
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size=768,
26
+ n_context_layers=2,
27
+ nlayers=6,
28
+ ntokens=32000,
29
+ nhead=8,
30
+ dropout=0.1,
31
+ dropatt=0.1,
32
+ relative_bias=False,
33
+ pos_emb=False,
34
+ pad=0,
35
+ n_parser_layers=4,
36
+ conv_size=9,
37
+ relations=('head', 'child'),
38
+ weight_act='softmax',
39
+ **kwargs,
40
+ ):
41
+ self.hidden_size = hidden_size
42
+ self.n_context_layers = n_context_layers
43
+ self.nlayers = nlayers
44
+ self.ntokens = ntokens
45
+ self.nhead = nhead
46
+ self.dropout = dropout
47
+ self.dropatt = dropatt
48
+ self.relative_bias = relative_bias
49
+ self.pos_emb = pos_emb
50
+ self.pad = pad
51
+ self.n_parser_layers = n_parser_layers
52
+ self.conv_size = conv_size
53
+ self.relations = relations
54
+ self.weight_act = weight_act
55
+ super().__init__(**kwargs)
56
+
57
+ ##########################################
58
+ # Custom Layers
59
+ ##########################################
60
+ def _get_activation_fn(activation):
61
+ """Get specified activation function."""
62
+ if activation == "relu":
63
+ return nn.ReLU()
64
+ elif activation == "gelu":
65
+ return nn.GELU()
66
+ elif activation == "leakyrelu":
67
+ return nn.LeakyReLU()
68
+
69
+ raise RuntimeError(
70
+ "activation should be relu/gelu, not {}".format(activation))
71
+
72
+ class Conv1d(nn.Module):
73
+ """1D convolution layer."""
74
+
75
+ def __init__(self, hidden_size, kernel_size, dilation=1):
76
+ """Initialization.
77
+ Args:
78
+ hidden_size: dimension of input embeddings
79
+ kernel_size: convolution kernel size
80
+ dilation: the spacing between the kernel points
81
+ """
82
+ super(Conv1d, self).__init__()
83
+
84
+ if kernel_size % 2 == 0:
85
+ padding = (kernel_size // 2) * dilation
86
+ self.shift = True
87
+ else:
88
+ padding = ((kernel_size - 1) // 2) * dilation
89
+ self.shift = False
90
+ self.conv = nn.Conv1d(
91
+ hidden_size,
92
+ hidden_size,
93
+ kernel_size,
94
+ padding=padding,
95
+ dilation=dilation)
96
+
97
+ def forward(self, x):
98
+ """Compute convolution.
99
+ Args:
100
+ x: input embeddings
101
+ Returns:
102
+ conv_output: convolution results
103
+ """
104
+
105
+ if self.shift:
106
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
107
+ else:
108
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
109
+
110
+ class MultiheadAttention(nn.Module):
111
+ """Multi-head self-attention layer."""
112
+
113
+ def __init__(self,
114
+ embed_dim,
115
+ num_heads,
116
+ dropout=0.,
117
+ bias=True,
118
+ v_proj=True,
119
+ out_proj=True,
120
+ relative_bias=True):
121
+ """Initialization.
122
+ Args:
123
+ embed_dim: dimension of input embeddings
124
+ num_heads: number of self-attention heads
125
+ dropout: dropout rate
126
+ bias: bool, indicate whether include bias for linear transformations
127
+ v_proj: bool, indicate whether project inputs to new values
128
+ out_proj: bool, indicate whether project outputs to new values
129
+ relative_bias: bool, indicate whether use a relative position based
130
+ attention bias
131
+ """
132
+
133
+ super(MultiheadAttention, self).__init__()
134
+ self.embed_dim = embed_dim
135
+
136
+ self.num_heads = num_heads
137
+ self.drop = nn.Dropout(dropout)
138
+ self.head_dim = embed_dim // num_heads
139
+ assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
140
+ "divisible by "
141
+ "num_heads")
142
+
143
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
144
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
145
+ if v_proj:
146
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
147
+ else:
148
+ self.v_proj = nn.Identity()
149
+
150
+ if out_proj:
151
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
152
+ else:
153
+ self.out_proj = nn.Identity()
154
+
155
+ if relative_bias:
156
+ self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
157
+ else:
158
+ self.relative_bias = None
159
+
160
+ self._reset_parameters()
161
+
162
+ def _reset_parameters(self):
163
+ """Initialize attention parameters."""
164
+
165
+ init.xavier_uniform_(self.q_proj.weight)
166
+ init.constant_(self.q_proj.bias, 0.)
167
+
168
+ init.xavier_uniform_(self.k_proj.weight)
169
+ init.constant_(self.k_proj.bias, 0.)
170
+
171
+ if isinstance(self.v_proj, nn.Linear):
172
+ init.xavier_uniform_(self.v_proj.weight)
173
+ init.constant_(self.v_proj.bias, 0.)
174
+
175
+ if isinstance(self.out_proj, nn.Linear):
176
+ init.xavier_uniform_(self.out_proj.weight)
177
+ init.constant_(self.out_proj.bias, 0.)
178
+
179
+ def forward(self, query, key_padding_mask=None, attn_mask=None):
180
+ """Compute multi-head self-attention.
181
+ Args:
182
+ query: input embeddings
183
+ key_padding_mask: 3D mask that prevents attention to certain positions
184
+ attn_mask: 3D mask that rescale the attention weight at each position
185
+ Returns:
186
+ attn_output: self-attention output
187
+ """
188
+
189
+ length, bsz, embed_dim = query.size()
190
+ assert embed_dim == self.embed_dim
191
+
192
+ head_dim = embed_dim // self.num_heads
193
+ assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
194
+ "divisible by num_heads")
195
+ scaling = float(head_dim)**-0.5
196
+
197
+ q = self.q_proj(query)
198
+ k = self.k_proj(query)
199
+ v = self.v_proj(query)
200
+
201
+ q = q * scaling
202
+
203
+ if attn_mask is not None:
204
+ assert list(attn_mask.size()) == [bsz * self.num_heads,
205
+ query.size(0), query.size(0)]
206
+
207
+ q = q.contiguous().view(length, bsz * self.num_heads,
208
+ head_dim).transpose(0, 1)
209
+ k = k.contiguous().view(length, bsz * self.num_heads,
210
+ head_dim).transpose(0, 1)
211
+ v = v.contiguous().view(length, bsz * self.num_heads,
212
+ head_dim).transpose(0, 1)
213
+
214
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
215
+ assert list(
216
+ attn_output_weights.size()) == [bsz * self.num_heads, length, length]
217
+
218
+ if self.relative_bias is not None:
219
+ pos = torch.arange(length, device=query.device)
220
+ relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
221
+ relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
222
+ -1)
223
+
224
+ relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
225
+ relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
226
+ relative_bias = torch.gather(relative_bias, 2, relative_pos)
227
+ attn_output_weights = attn_output_weights + relative_bias
228
+
229
+ if key_padding_mask is not None:
230
+ attn_output_weights = attn_output_weights + key_padding_mask
231
+
232
+ if attn_mask is None:
233
+ attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
234
+ else:
235
+ attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
236
+
237
+ attn_output_weights = self.drop(attn_output_weights)
238
+
239
+ attn_output = torch.bmm(attn_output_weights, v)
240
+
241
+ assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
242
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
243
+ length, bsz, embed_dim)
244
+ attn_output = self.out_proj(attn_output)
245
+
246
+ return attn_output
247
+
248
+ class TransformerLayer(nn.Module):
249
+ """TransformerEncoderLayer is made up of self-attn and feedforward network."""
250
+
251
+ def __init__(self,
252
+ d_model,
253
+ nhead,
254
+ dim_feedforward=2048,
255
+ dropout=0.1,
256
+ dropatt=0.1,
257
+ activation="leakyrelu",
258
+ relative_bias=True):
259
+ """Initialization.
260
+ Args:
261
+ d_model: dimension of inputs
262
+ nhead: number of self-attention heads
263
+ dim_feedforward: dimension of hidden layer in feedforward layer
264
+ dropout: dropout rate
265
+ dropatt: drop attention rate
266
+ activation: activation function
267
+ relative_bias: bool, indicate whether use a relative position based
268
+ attention bias
269
+ """
270
+
271
+ super(TransformerLayer, self).__init__()
272
+
273
+ self.self_attn = MultiheadAttention(
274
+ d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
275
+
276
+ # Implementation of Feedforward model
277
+ self.feedforward = nn.Sequential(
278
+ nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
279
+ _get_activation_fn(activation), nn.Dropout(dropout),
280
+ nn.Linear(dim_feedforward, d_model))
281
+
282
+ self.norm = nn.LayerNorm(d_model)
283
+ self.dropout1 = nn.Dropout(dropout)
284
+ self.dropout2 = nn.Dropout(dropout)
285
+
286
+ self.nhead = nhead
287
+
288
+ def forward(self, src, attn_mask=None, key_padding_mask=None):
289
+ """Pass the input through the encoder layer.
290
+ Args:
291
+ src: the sequence to the encoder layer (required).
292
+ attn_mask: the mask for the src sequence (optional).
293
+ key_padding_mask: the mask for the src keys per batch (optional).
294
+ Returns:
295
+ src3: the output of transformer layer, share the same shape as src.
296
+ """
297
+ src2 = self.self_attn(
298
+ self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
299
+ src2 = src + self.dropout1(src2)
300
+ src3 = self.feedforward(src2)
301
+ src3 = src2 + self.dropout2(src3)
302
+
303
+ return src3
304
+
305
+
306
+
307
+ class RobertaClassificationHead(nn.Module):
308
+ """Head for sentence-level classification tasks."""
309
+
310
+ def __init__(self, config):
311
+ super().__init__()
312
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
313
+ classifier_dropout = (
314
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
315
+ )
316
+ self.dropout = nn.Dropout(classifier_dropout)
317
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
318
+
319
+ def forward(self, features, **kwargs):
320
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
321
+ x = self.dropout(x)
322
+ x = self.dense(x)
323
+ x = torch.tanh(x)
324
+ x = self.dropout(x)
325
+ x = self.out_proj(x)
326
+ return x
327
+
328
+
329
+ ##########################################
330
+ # Custom Models
331
+ ##########################################
332
+ def cumprod(x, reverse=False, exclusive=False):
333
+ """cumulative product."""
334
+ if reverse:
335
+ x = x.flip([-1])
336
+
337
+ if exclusive:
338
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
339
+
340
+ cx = x.cumprod(-1)
341
+
342
+ if reverse:
343
+ cx = cx.flip([-1])
344
+ return cx
345
+
346
+ def cumsum(x, reverse=False, exclusive=False):
347
+ """cumulative sum."""
348
+ bsz, _, length = x.size()
349
+ device = x.device
350
+ if reverse:
351
+ if exclusive:
352
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
353
+ else:
354
+ w = torch.ones([bsz, length, length], device=device).tril(0)
355
+ cx = torch.bmm(x, w)
356
+ else:
357
+ if exclusive:
358
+ w = torch.ones([bsz, length, length], device=device).triu(1)
359
+ else:
360
+ w = torch.ones([bsz, length, length], device=device).triu(0)
361
+ cx = torch.bmm(x, w)
362
+ return cx
363
+
364
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
365
+ """cumulative min."""
366
+ if reverse:
367
+ if exclusive:
368
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
369
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
370
+ else:
371
+ if exclusive:
372
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
373
+ x = x.cummin(-1)[0]
374
+ return x
375
+
376
+ class Transformer(nn.Module):
377
+ """Transformer model."""
378
+
379
+ def __init__(self,
380
+ hidden_size,
381
+ nlayers,
382
+ ntokens,
383
+ nhead=8,
384
+ dropout=0.1,
385
+ dropatt=0.1,
386
+ relative_bias=True,
387
+ pos_emb=False,
388
+ pad=0):
389
+ """Initialization.
390
+ Args:
391
+ hidden_size: dimension of inputs and hidden states
392
+ nlayers: number of layers
393
+ ntokens: number of output categories
394
+ nhead: number of self-attention heads
395
+ dropout: dropout rate
396
+ dropatt: drop attention rate
397
+ relative_bias: bool, indicate whether use a relative position based
398
+ attention bias
399
+ pos_emb: bool, indicate whether use a learnable positional embedding
400
+ pad: pad token index
401
+ """
402
+
403
+ super(Transformer, self).__init__()
404
+
405
+ self.drop = nn.Dropout(dropout)
406
+
407
+ self.emb = nn.Embedding(ntokens, hidden_size)
408
+ if pos_emb:
409
+ self.pos_emb = nn.Embedding(500, hidden_size)
410
+
411
+ self.layers = nn.ModuleList([
412
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
413
+ dropatt=dropatt, relative_bias=relative_bias)
414
+ for _ in range(nlayers)])
415
+
416
+ self.norm = nn.LayerNorm(hidden_size)
417
+
418
+ self.output_layer = nn.Linear(hidden_size, ntokens)
419
+ self.output_layer.weight = self.emb.weight
420
+
421
+ self.init_weights()
422
+
423
+ self.nlayers = nlayers
424
+ self.nhead = nhead
425
+ self.ntokens = ntokens
426
+ self.hidden_size = hidden_size
427
+ self.pad = pad
428
+
429
+ def init_weights(self):
430
+ """Initialize token embedding and output bias."""
431
+ initrange = 0.1
432
+ self.emb.weight.data.uniform_(-initrange, initrange)
433
+ if hasattr(self, 'pos_emb'):
434
+ self.pos_emb.weight.data.uniform_(-initrange, initrange)
435
+ self.output_layer.bias.data.fill_(0)
436
+
437
+ def visibility(self, x, device):
438
+ """Mask pad tokens."""
439
+ visibility = (x != self.pad).float()
440
+ visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
441
+ visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
442
+ return visibility.log()
443
+
444
+ def encode(self, x, pos):
445
+ """Standard transformer encode process."""
446
+ h = self.emb(x)
447
+ if hasattr(self, 'pos_emb'):
448
+ h = h + self.pos_emb(pos)
449
+ h_list = []
450
+ visibility = self.visibility(x, x.device)
451
+
452
+ for i in range(self.nlayers):
453
+ h_list.append(h)
454
+ h = self.layers[i](
455
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
456
+
457
+ output = h
458
+ h_array = torch.stack(h_list, dim=2)
459
+
460
+ return output, h_array
461
+
462
+ def forward(self, x, pos):
463
+ """Pass the input through the encoder layer.
464
+ Args:
465
+ x: input tokens (required).
466
+ pos: position for each token (optional).
467
+ Returns:
468
+ output: probability distributions for missing tokens.
469
+ state_dict: parsing results and raw output
470
+ """
471
+
472
+ batch_size, length = x.size()
473
+
474
+ raw_output, _ = self.encode(x, pos)
475
+ raw_output = self.norm(raw_output)
476
+ raw_output = self.drop(raw_output)
477
+
478
+ output = self.output_layer(raw_output)
479
+ return output.view(batch_size * length, -1), {'raw_output': raw_output,}
480
+
481
+ class StructFormer(Transformer):
482
+ """StructFormer model."""
483
+
484
+ def __init__(self,
485
+ hidden_size,
486
+ n_context_layers,
487
+ nlayers,
488
+ ntokens,
489
+ nhead=8,
490
+ dropout=0.1,
491
+ dropatt=0.1,
492
+ relative_bias=False,
493
+ pos_emb=False,
494
+ pad=0,
495
+ n_parser_layers=4,
496
+ conv_size=9,
497
+ relations=('head', 'child'),
498
+ weight_act='softmax'):
499
+ """Initialization.
500
+ Args:
501
+ hidden_size: dimension of inputs and hidden states
502
+ nlayers: number of layers
503
+ ntokens: number of output categories
504
+ nhead: number of self-attention heads
505
+ dropout: dropout rate
506
+ dropatt: drop attention rate
507
+ relative_bias: bool, indicate whether use a relative position based
508
+ attention bias
509
+ pos_emb: bool, indicate whether use a learnable positional embedding
510
+ pad: pad token index
511
+ n_parser_layers: number of parsing layers
512
+ conv_size: convolution kernel size for parser
513
+ relations: relations that are used to compute self attention
514
+ weight_act: relations distribution activation function
515
+ """
516
+
517
+ super(StructFormer, self).__init__(
518
+ hidden_size,
519
+ nlayers,
520
+ ntokens,
521
+ nhead=nhead,
522
+ dropout=dropout,
523
+ dropatt=dropatt,
524
+ relative_bias=relative_bias,
525
+ pos_emb=pos_emb,
526
+ pad=pad)
527
+
528
+ if n_context_layers > 0:
529
+ self.context_layers = nn.ModuleList([
530
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
531
+ dropatt=dropatt, relative_bias=relative_bias)
532
+ for _ in range(n_context_layers)])
533
+
534
+ self.parser_layers = nn.ModuleList([
535
+ nn.Sequential(Conv1d(hidden_size, conv_size),
536
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
537
+ nn.Tanh()) for i in range(n_parser_layers)])
538
+
539
+ self.distance_ff = nn.Sequential(
540
+ Conv1d(hidden_size, 2),
541
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
542
+ nn.Linear(hidden_size, 1))
543
+
544
+ self.height_ff = nn.Sequential(
545
+ nn.Linear(hidden_size, hidden_size),
546
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
547
+ nn.Linear(hidden_size, 1))
548
+
549
+ n_rel = len(relations)
550
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
551
+ self._rel_weight.data.normal_(0, 0.1)
552
+
553
+ self._scaler = nn.Parameter(torch.zeros(2))
554
+
555
+ self.n_parse_layers = n_parser_layers
556
+ self.n_context_layers = n_context_layers
557
+ self.weight_act = weight_act
558
+ self.relations = relations
559
+
560
+ @property
561
+ def scaler(self):
562
+ return self._scaler.exp()
563
+
564
+ @property
565
+ def rel_weight(self):
566
+ if self.weight_act == 'sigmoid':
567
+ return torch.sigmoid(self._rel_weight)
568
+ elif self.weight_act == 'softmax':
569
+ return torch.softmax(self._rel_weight, dim=-1)
570
+
571
+ def parse(self, x, pos, embeds=None):
572
+ """Parse input sentence.
573
+ Args:
574
+ x: input tokens (required).
575
+ pos: position for each token (optional).
576
+ Returns:
577
+ distance: syntactic distance
578
+ height: syntactic height
579
+ """
580
+
581
+ mask = (x != self.pad)
582
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
583
+
584
+
585
+ if embeds is not None:
586
+ h = embeds
587
+ else:
588
+ h = self.emb(x)
589
+
590
+ for i in range(self.n_parse_layers):
591
+ h = h.masked_fill(~mask[:, :, None], 0)
592
+ h = self.parser_layers[i](h)
593
+
594
+ height = self.height_ff(h).squeeze(-1)
595
+ height.masked_fill_(~mask, -1e9)
596
+
597
+ distance = self.distance_ff(h).squeeze(-1)
598
+ distance.masked_fill_(~mask_shifted, 1e9)
599
+
600
+ # Calibrating the distance and height to the same level
601
+ length = distance.size(1)
602
+ height_max = height[:, None, :].expand(-1, length, -1)
603
+ height_max = torch.cummax(
604
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
605
+ dim=-1)[0].triu(0)
606
+
607
+ margin_left = torch.relu(
608
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
609
+ margin_right = torch.relu(distance[:, None, :] - height_max)
610
+ margin = torch.where(margin_left > margin_right, margin_right,
611
+ margin_left).triu(0)
612
+
613
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
614
+ margin.masked_fill_(~margin_mask, 0)
615
+ margin = margin.max()
616
+
617
+ distance = distance - margin
618
+
619
+ return distance, height
620
+
621
+ def compute_block(self, distance, height):
622
+ """Compute constituents from distance and height."""
623
+
624
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
625
+
626
+ gamma = torch.sigmoid(-beta_logits)
627
+ ones = torch.ones_like(gamma)
628
+
629
+ block_mask_left = cummin(
630
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
631
+ block_mask_left = block_mask_left - F.pad(
632
+ block_mask_left[:, :, :-1], (1, 0), value=0)
633
+ block_mask_left.tril_(0)
634
+
635
+ block_mask_right = cummin(
636
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
637
+ block_mask_right = block_mask_right - F.pad(
638
+ block_mask_right[:, :, 1:], (0, 1), value=0)
639
+ block_mask_right.triu_(0)
640
+
641
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
642
+ block = cumsum(block_mask_left).tril(0) + cumsum(
643
+ block_mask_right, reverse=True).triu(1)
644
+
645
+ return block_p, block
646
+
647
+ def compute_head(self, height):
648
+ """Estimate head for each constituent."""
649
+
650
+ _, length = height.size()
651
+ head_logits = height * self.scaler[1]
652
+ index = torch.arange(length, device=height.device)
653
+
654
+ mask = (index[:, None, None] <= index[None, None, :]) * (
655
+ index[None, None, :] <= index[None, :, None])
656
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
657
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
658
+
659
+ head_p = torch.softmax(head_logits, dim=-1)
660
+
661
+ return head_p
662
+
663
+ def generate_mask(self, x, distance, height):
664
+ """Compute head and cibling distribution for each token."""
665
+
666
+ bsz, length = x.size()
667
+
668
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
669
+ eye = eye[None, :, :].expand((bsz, -1, -1))
670
+
671
+ block_p, block = self.compute_block(distance, height)
672
+ head_p = self.compute_head(height)
673
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
674
+ head = head.masked_fill(eye, 0)
675
+ child = head.transpose(1, 2)
676
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
677
+
678
+ rel_list = []
679
+ if 'head' in self.relations:
680
+ rel_list.append(head)
681
+ if 'child' in self.relations:
682
+ rel_list.append(child)
683
+ if 'cibling' in self.relations:
684
+ rel_list.append(cibling)
685
+
686
+ rel = torch.stack(rel_list, dim=1)
687
+
688
+ rel_weight = self.rel_weight
689
+
690
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
691
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
692
+
693
+ return att_mask, cibling, head, block
694
+
695
+ def encode(self, x, pos, att_mask=None, context_layers=False):
696
+ """Structformer encoding process."""
697
+
698
+ if context_layers:
699
+ """Standard transformer encode process."""
700
+ h = self.emb(x)
701
+ if hasattr(self, 'pos_emb'):
702
+ h = h + self.pos_emb(pos)
703
+ h_list = []
704
+ visibility = self.visibility(x, x.device)
705
+ for i in range(self.n_context_layers):
706
+ h_list.append(h)
707
+ h = self.context_layers[i](
708
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
709
+
710
+ output = h
711
+ h_array = torch.stack(h_list, dim=2)
712
+ return output
713
+
714
+ else:
715
+ visibility = self.visibility(x, x.device)
716
+ h = self.emb(x)
717
+ if hasattr(self, 'pos_emb'):
718
+ assert pos.max() < 500
719
+ h = h + self.pos_emb(pos)
720
+ for i in range(self.nlayers):
721
+ h = self.layers[i](
722
+ h.transpose(0, 1), attn_mask=att_mask[i],
723
+ key_padding_mask=visibility).transpose(0, 1)
724
+ return h
725
+
726
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
727
+
728
+ x = input_ids
729
+ batch_size, length = x.size()
730
+
731
+ if position_ids is None:
732
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
733
+
734
+ context_layers_output = None
735
+ if self.n_context_layers > 0:
736
+ context_layers_output = self.encode(x, pos, context_layers=True)
737
+
738
+ distance, height = self.parse(x, pos, embeds=context_layers_output)
739
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
740
+
741
+ raw_output = self.encode(x, pos, att_mask)
742
+ raw_output = self.norm(raw_output)
743
+ raw_output = self.drop(raw_output)
744
+
745
+ output = self.output_layer(raw_output)
746
+
747
+ loss = None
748
+ if labels is not None:
749
+ loss_fct = nn.CrossEntropyLoss()
750
+ loss = loss_fct(output.view(batch_size * length, -1), labels.reshape(-1))
751
+
752
+ return MaskedLMOutput(
753
+ loss=loss, # shape: 1
754
+ logits=output, # shape: (batch_size * length, ntokens)
755
+ hidden_states=None,
756
+ attentions=None,
757
+ )
758
+
759
+
760
+
761
+
762
+ class StructFormerClassification(Transformer):
763
+ """StructFormer model."""
764
+
765
+ def __init__(self,
766
+ hidden_size,
767
+ n_context_layers,
768
+ nlayers,
769
+ ntokens,
770
+ nhead=8,
771
+ dropout=0.1,
772
+ dropatt=0.1,
773
+ relative_bias=False,
774
+ pos_emb=False,
775
+ pad=0,
776
+ n_parser_layers=4,
777
+ conv_size=9,
778
+ relations=('head', 'child'),
779
+ weight_act='softmax',
780
+ config=None,
781
+ ):
782
+
783
+
784
+ super(StructFormerClassification, self).__init__(
785
+ hidden_size,
786
+ nlayers,
787
+ ntokens,
788
+ nhead=nhead,
789
+ dropout=dropout,
790
+ dropatt=dropatt,
791
+ relative_bias=relative_bias,
792
+ pos_emb=pos_emb,
793
+ pad=pad)
794
+
795
+ self.num_labels = config.num_labels
796
+ self.config = config
797
+
798
+ if n_context_layers > 0:
799
+ self.context_layers = nn.ModuleList([
800
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
801
+ dropatt=dropatt, relative_bias=relative_bias)
802
+ for _ in range(n_context_layers)])
803
+
804
+ self.parser_layers = nn.ModuleList([
805
+ nn.Sequential(Conv1d(hidden_size, conv_size),
806
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
807
+ nn.Tanh()) for i in range(n_parser_layers)])
808
+
809
+ self.distance_ff = nn.Sequential(
810
+ Conv1d(hidden_size, 2),
811
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
812
+ nn.Linear(hidden_size, 1))
813
+
814
+ self.height_ff = nn.Sequential(
815
+ nn.Linear(hidden_size, hidden_size),
816
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
817
+ nn.Linear(hidden_size, 1))
818
+
819
+ n_rel = len(relations)
820
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
821
+ self._rel_weight.data.normal_(0, 0.1)
822
+
823
+ self._scaler = nn.Parameter(torch.zeros(2))
824
+
825
+ self.n_parse_layers = n_parser_layers
826
+ self.n_context_layers = n_context_layers
827
+ self.weight_act = weight_act
828
+ self.relations = relations
829
+
830
+ self.classifier = RobertaClassificationHead(config)
831
+
832
+ @property
833
+ def scaler(self):
834
+ return self._scaler.exp()
835
+
836
+ @property
837
+ def rel_weight(self):
838
+ if self.weight_act == 'sigmoid':
839
+ return torch.sigmoid(self._rel_weight)
840
+ elif self.weight_act == 'softmax':
841
+ return torch.softmax(self._rel_weight, dim=-1)
842
+
843
+ def parse(self, x, pos, embeds=None):
844
+ """Parse input sentence.
845
+ Args:
846
+ x: input tokens (required).
847
+ pos: position for each token (optional).
848
+ Returns:
849
+ distance: syntactic distance
850
+ height: syntactic height
851
+ """
852
+
853
+ mask = (x != self.pad)
854
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
855
+
856
+
857
+ if embeds is not None:
858
+ h = embeds
859
+ else:
860
+ h = self.emb(x)
861
+
862
+ for i in range(self.n_parse_layers):
863
+ h = h.masked_fill(~mask[:, :, None], 0)
864
+ h = self.parser_layers[i](h)
865
+
866
+ height = self.height_ff(h).squeeze(-1)
867
+ height.masked_fill_(~mask, -1e9)
868
+
869
+ distance = self.distance_ff(h).squeeze(-1)
870
+ distance.masked_fill_(~mask_shifted, 1e9)
871
+
872
+ # Calibrating the distance and height to the same level
873
+ length = distance.size(1)
874
+ height_max = height[:, None, :].expand(-1, length, -1)
875
+ height_max = torch.cummax(
876
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
877
+ dim=-1)[0].triu(0)
878
+
879
+ margin_left = torch.relu(
880
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
881
+ margin_right = torch.relu(distance[:, None, :] - height_max)
882
+ margin = torch.where(margin_left > margin_right, margin_right,
883
+ margin_left).triu(0)
884
+
885
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
886
+ margin.masked_fill_(~margin_mask, 0)
887
+ margin = margin.max()
888
+
889
+ distance = distance - margin
890
+
891
+ return distance, height
892
+
893
+ def compute_block(self, distance, height):
894
+ """Compute constituents from distance and height."""
895
+
896
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
897
+
898
+ gamma = torch.sigmoid(-beta_logits)
899
+ ones = torch.ones_like(gamma)
900
+
901
+ block_mask_left = cummin(
902
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
903
+ block_mask_left = block_mask_left - F.pad(
904
+ block_mask_left[:, :, :-1], (1, 0), value=0)
905
+ block_mask_left.tril_(0)
906
+
907
+ block_mask_right = cummin(
908
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
909
+ block_mask_right = block_mask_right - F.pad(
910
+ block_mask_right[:, :, 1:], (0, 1), value=0)
911
+ block_mask_right.triu_(0)
912
+
913
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
914
+ block = cumsum(block_mask_left).tril(0) + cumsum(
915
+ block_mask_right, reverse=True).triu(1)
916
+
917
+ return block_p, block
918
+
919
+ def compute_head(self, height):
920
+ """Estimate head for each constituent."""
921
+
922
+ _, length = height.size()
923
+ head_logits = height * self.scaler[1]
924
+ index = torch.arange(length, device=height.device)
925
+
926
+ mask = (index[:, None, None] <= index[None, None, :]) * (
927
+ index[None, None, :] <= index[None, :, None])
928
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
929
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
930
+
931
+ head_p = torch.softmax(head_logits, dim=-1)
932
+
933
+ return head_p
934
+
935
+ def generate_mask(self, x, distance, height):
936
+ """Compute head and cibling distribution for each token."""
937
+
938
+ bsz, length = x.size()
939
+
940
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
941
+ eye = eye[None, :, :].expand((bsz, -1, -1))
942
+
943
+ block_p, block = self.compute_block(distance, height)
944
+ head_p = self.compute_head(height)
945
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
946
+ head = head.masked_fill(eye, 0)
947
+ child = head.transpose(1, 2)
948
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
949
+
950
+ rel_list = []
951
+ if 'head' in self.relations:
952
+ rel_list.append(head)
953
+ if 'child' in self.relations:
954
+ rel_list.append(child)
955
+ if 'cibling' in self.relations:
956
+ rel_list.append(cibling)
957
+
958
+ rel = torch.stack(rel_list, dim=1)
959
+
960
+ rel_weight = self.rel_weight
961
+
962
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
963
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
964
+
965
+ return att_mask, cibling, head, block
966
+
967
+ def encode(self, x, pos, att_mask=None, context_layers=False):
968
+ """Structformer encoding process."""
969
+
970
+ if context_layers:
971
+ """Standard transformer encode process."""
972
+ h = self.emb(x)
973
+ if hasattr(self, 'pos_emb'):
974
+ h = h + self.pos_emb(pos)
975
+ h_list = []
976
+ visibility = self.visibility(x, x.device)
977
+ for i in range(self.n_context_layers):
978
+ h_list.append(h)
979
+ h = self.context_layers[i](
980
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
981
+
982
+ output = h
983
+ h_array = torch.stack(h_list, dim=2)
984
+ return output
985
+
986
+ else:
987
+ visibility = self.visibility(x, x.device)
988
+ h = self.emb(x)
989
+ if hasattr(self, 'pos_emb'):
990
+ assert pos.max() < 500
991
+ h = h + self.pos_emb(pos)
992
+ for i in range(self.nlayers):
993
+ h = self.layers[i](
994
+ h.transpose(0, 1), attn_mask=att_mask[i],
995
+ key_padding_mask=visibility).transpose(0, 1)
996
+ return h
997
+
998
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
999
+
1000
+ x = input_ids
1001
+ batch_size, length = x.size()
1002
+
1003
+ if position_ids is None:
1004
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
1005
+
1006
+ context_layers_output = None
1007
+ if self.n_context_layers > 0:
1008
+ context_layers_output = self.encode(x, pos, context_layers=True)
1009
+
1010
+ distance, height = self.parse(x, pos, embeds=context_layers_output)
1011
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
1012
+
1013
+ raw_output = self.encode(x, pos, att_mask)
1014
+ raw_output = self.norm(raw_output)
1015
+ raw_output = self.drop(raw_output)
1016
+
1017
+ #output = self.output_layer(raw_output)
1018
+ logits = self.classifier(raw_output)
1019
+
1020
+ loss = None
1021
+ if labels is not None:
1022
+ if self.config.problem_type is None:
1023
+ if self.num_labels == 1:
1024
+ self.config.problem_type = "regression"
1025
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1026
+ self.config.problem_type = "single_label_classification"
1027
+ else:
1028
+ self.config.problem_type = "multi_label_classification"
1029
+
1030
+ if self.config.problem_type == "regression":
1031
+ loss_fct = MSELoss()
1032
+ if self.num_labels == 1:
1033
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1034
+ else:
1035
+ loss = loss_fct(logits, labels)
1036
+ elif self.config.problem_type == "single_label_classification":
1037
+ loss_fct = CrossEntropyLoss()
1038
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1039
+ elif self.config.problem_type == "multi_label_classification":
1040
+ loss_fct = BCEWithLogitsLoss()
1041
+ loss = loss_fct(logits, labels)
1042
+
1043
+
1044
+ return SequenceClassifierOutput(
1045
+ loss=loss,
1046
+ logits=logits,
1047
+ hidden_states=None,
1048
+ attentions=None,
1049
+ )
1050
+
1051
+
1052
+
1053
+ ##########################################
1054
+ # HuggingFace Model
1055
+ ##########################################
1056
+ class StructformerModel(PreTrainedModel):
1057
+ config_class = StructformerConfig
1058
+
1059
+ def __init__(self, config):
1060
+ super().__init__(config)
1061
+ self.model = StructFormer(
1062
+ hidden_size=config.hidden_size,
1063
+ n_context_layers=config.n_context_layers,
1064
+ nlayers=config.nlayers,
1065
+ ntokens=config.ntokens,
1066
+ nhead=config.nhead,
1067
+ dropout=config.dropout,
1068
+ dropatt=config.dropatt,
1069
+ relative_bias=config.relative_bias,
1070
+ pos_emb=config.pos_emb,
1071
+ pad=config.pad,
1072
+ n_parser_layers=config.n_parser_layers,
1073
+ conv_size=config.conv_size,
1074
+ relations=config.relations,
1075
+ weight_act=config.weight_act
1076
+ )
1077
+
1078
+ def forward(self, input_ids, labels=None, **kwargs):
1079
+ return self.model(input_ids, labels=labels, **kwargs)
1080
+
1081
+
1082
+
1083
+ class StructformerModelForSequenceClassification(PreTrainedModel):
1084
+ config_class = StructformerConfig
1085
+ def __init__(self, config):
1086
+ super().__init__(config)
1087
+ self.model = StructFormerClassification(
1088
+ hidden_size=config.hidden_size,
1089
+ n_context_layers=config.n_context_layers,
1090
+ nlayers=config.nlayers,
1091
+ ntokens=config.ntokens,
1092
+ nhead=config.nhead,
1093
+ dropout=config.dropout,
1094
+ dropatt=config.dropatt,
1095
+ relative_bias=config.relative_bias,
1096
+ pos_emb=config.pos_emb,
1097
+ pad=config.pad,
1098
+ n_parser_layers=config.n_parser_layers,
1099
+ conv_size=config.conv_size,
1100
+ relations=config.relations,
1101
+ weight_act=config.weight_act,
1102
+ config=config)
1103
+
1104
+ def _init_weights(self, module):
1105
+ """Initialize the weights"""
1106
+ if isinstance(module, nn.Linear):
1107
+ # Slightly different from the TF version which uses truncated_normal for initialization
1108
+ # cf https://github.com/pytorch/pytorch/pull/5617
1109
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1110
+ if module.bias is not None:
1111
+ module.bias.data.zero_()
1112
+ elif isinstance(module, nn.Embedding):
1113
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1114
+ if module.padding_idx is not None:
1115
+ module.weight.data[module.padding_idx].zero_()
1116
+ elif isinstance(module, nn.LayerNorm):
1117
+ if module.bias is not None:
1118
+ module.bias.data.zero_()
1119
+ module.weight.data.fill_(1.0)
1120
+
1121
+
1122
+ def forward(self, input_ids, labels=None, **kwargs):
1123
+ return self.model(input_ids, labels=labels, **kwargs)
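Note (editor's sketch, not part of the commit): structformer_as_hf.py is the module that the auto_map block in the accompanying config.json points at, so the custom classes load through the standard Hugging Face auto classes once trust_remote_code=True is passed. A minimal example follows; the local path "finetune/cola" is an assumption about where this repository is checked out.

import torch
from transformers import AutoTokenizer, AutoModelForSequenceClassification

path = "finetune/cola"  # assumed local checkout of one fine-tuning folder
tokenizer = AutoTokenizer.from_pretrained(path)
model = AutoModelForSequenceClassification.from_pretrained(path, trust_remote_code=True)
model.eval()

enc = tokenizer("The cat sat on the mat.", return_tensors="pt")
with torch.no_grad():
    out = model(input_ids=enc["input_ids"])  # forward() above consumes input_ids (and labels, if given)
print(out.logits.argmax(dim=-1))  # predicted label id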
finetune/cola/tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "cls_token": {
12
+ "__type": "AddedToken",
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "eos_token": {
20
+ "__type": "AddedToken",
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "errors": "replace",
28
+ "mask_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<mask>",
31
+ "lstrip": true,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ },
36
+ "model_max_length": 512,
37
+ "name_or_path": "omarmomen/structformer_s1_final_with_pos",
38
+ "pad_token": {
39
+ "__type": "AddedToken",
40
+ "content": "<pad>",
41
+ "lstrip": false,
42
+ "normalized": true,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ },
46
+ "sep_token": {
47
+ "__type": "AddedToken",
48
+ "content": "</s>",
49
+ "lstrip": false,
50
+ "normalized": true,
51
+ "rstrip": false,
52
+ "single_word": false
53
+ },
54
+ "special_tokens_map_file": null,
55
+ "tokenizer_class": "RobertaTokenizer",
56
+ "trim_offsets": true,
57
+ "unk_token": {
58
+ "__type": "AddedToken",
59
+ "content": "<unk>",
60
+ "lstrip": false,
61
+ "normalized": true,
62
+ "rstrip": false,
63
+ "single_word": false
64
+ }
65
+ }
finetune/cola/train_results.json ADDED
@@ -0,0 +1,8 @@
1
+ {
2
+ "epoch": 10.0,
3
+ "train_loss": 0.3857684093972911,
4
+ "train_runtime": 391.5771,
5
+ "train_samples": 8164,
6
+ "train_samples_per_second": 208.49,
7
+ "train_steps_per_second": 1.762
8
+ }
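A quick consistency check (editor's note, not part of the commit): the throughput in train_results.json follows from the other fields, since 8164 samples over 10 epochs in roughly 391.6 seconds gives about 208.5 samples per second.

train_samples = 8164
epochs = 10.0
train_runtime = 391.5771  # seconds
print(round(train_samples * epochs / train_runtime, 2))  # 208.49 == "train_samples_per_second"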
finetune/cola/trainer_state.json ADDED
@@ -0,0 +1,42 @@
1
+ {
2
+ "best_metric": 0.7931937172774869,
3
+ "best_model_checkpoint": "finetune_results/omarmomen/structformer_s1_final_with_pos/cola/checkpoint-400",
4
+ "epoch": 10.0,
5
+ "global_step": 690,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 5.8,
12
+ "eval_accuracy": 0.689892053604126,
13
+ "eval_f1": 0.7931937172774869,
14
+ "eval_loss": 0.762876570224762,
15
+ "eval_mcc": 0.19495355846277174,
16
+ "eval_runtime": 1.9776,
17
+ "eval_samples_per_second": 515.265,
18
+ "eval_steps_per_second": 64.724,
19
+ "step": 400
20
+ },
21
+ {
22
+ "epoch": 7.25,
23
+ "learning_rate": 1.3768115942028985e-05,
24
+ "loss": 0.4623,
25
+ "step": 500
26
+ },
27
+ {
28
+ "epoch": 10.0,
29
+ "step": 690,
30
+ "total_flos": 6814792144343040.0,
31
+ "train_loss": 0.3857684093972911,
32
+ "train_runtime": 391.5771,
33
+ "train_samples_per_second": 208.49,
34
+ "train_steps_per_second": 1.762
35
+ }
36
+ ],
37
+ "max_steps": 690,
38
+ "num_train_epochs": 10,
39
+ "total_flos": 6814792144343040.0,
40
+ "trial_name": null,
41
+ "trial_params": null
42
+ }
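The trainer state above is likewise internally consistent (editor's note, not part of the commit): 690 optimizer steps over 10 epochs means 69 steps per epoch, so the best checkpoint at step 400 lands at the logged epoch 5.8.

max_steps = 690
num_train_epochs = 10
best_step = 400
steps_per_epoch = max_steps / num_train_epochs  # 69.0
print(round(best_step / steps_per_epoch, 1))    # 5.8, the epoch recorded for checkpoint-400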
finetune/cola/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:9f4559102ecfa4341cceea0c32d0db1fe5185ac1bd8ad00af1c57c8dfaa12f80
3
+ size 3503
finetune/cola/vocab.json ADDED
The diff for this file is too large to render. See raw diff
 
finetune/control_raising_control/all_results.json ADDED
@@ -0,0 +1,16 @@
1
+ {
2
+ "epoch": 10.0,
3
+ "eval_accuracy": 0.7487743496894836,
4
+ "eval_f1": 0.7944072948328267,
5
+ "eval_loss": 2.3855533599853516,
6
+ "eval_mcc": 0.5565748633817942,
7
+ "eval_runtime": 13.0758,
8
+ "eval_samples": 6731,
9
+ "eval_samples_per_second": 514.768,
10
+ "eval_steps_per_second": 64.394,
11
+ "train_loss": 0.08940400845815681,
12
+ "train_runtime": 342.2215,
13
+ "train_samples": 6570,
14
+ "train_samples_per_second": 191.981,
15
+ "train_steps_per_second": 1.607
16
+ }
finetune/control_raising_control/checkpoint-400/config.json ADDED
@@ -0,0 +1,57 @@
1
+ {
2
+ "_name_or_path": "omarmomen/structformer_s1_final_with_pos",
3
+ "architectures": [
4
+ "StructformerModelForSequenceClassification"
5
+ ],
6
+ "attention_probs_dropout_prob": 0.1,
7
+ "auto_map": {
8
+ "AutoConfig": "structformer_as_hf.StructformerConfig",
9
+ "AutoModelForMaskedLM": "structformer_as_hf.StructformerModel",
10
+ "AutoModelForSequenceClassification": "structformer_as_hf.StructformerModelForSequenceClassification"
11
+ },
12
+ "bos_token_id": 0,
13
+ "classifier_dropout": null,
14
+ "conv_size": 9,
15
+ "dropatt": 0.1,
16
+ "dropout": 0.1,
17
+ "eos_token_id": 2,
18
+ "hidden_act": "gelu",
19
+ "hidden_dropout_prob": 0.1,
20
+ "hidden_size": 768,
21
+ "id2label": {
22
+ "0": 0,
23
+ "1": 1
24
+ },
25
+ "initializer_range": 0.02,
26
+ "intermediate_size": 3072,
27
+ "label2id": {
28
+ "0": 0,
29
+ "1": 1
30
+ },
31
+ "layer_norm_eps": 1e-05,
32
+ "max_position_embeddings": 514,
33
+ "model_type": "structformer",
34
+ "n_context_layers": 0,
35
+ "n_parser_layers": 4,
36
+ "nhead": 12,
37
+ "nlayers": 12,
38
+ "ntokens": 32000,
39
+ "num_attention_heads": 8,
40
+ "num_hidden_layers": 8,
41
+ "pad": 0,
42
+ "pad_token_id": 1,
43
+ "pos_emb": true,
44
+ "position_embedding_type": "absolute",
45
+ "problem_type": "single_label_classification",
46
+ "relations": [
47
+ "head",
48
+ "child"
49
+ ],
50
+ "relative_bias": false,
51
+ "torch_dtype": "float32",
52
+ "transformers_version": "4.26.1",
53
+ "type_vocab_size": 1,
54
+ "use_cache": true,
55
+ "vocab_size": 32000,
56
+ "weight_act": "softmax"
57
+ }
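The config above feeds the StructformerConfig constructor defined in structformer_as_hf.py; fields such as hidden_size, nlayers, nhead, ntokens, pos_emb and the parser settings drive the model, while num_labels, classifier_dropout and hidden_dropout_prob are read by the classification head. A minimal sketch (editor's note, not part of the commit) of rebuilding an untrained classifier from these values, assuming structformer_as_hf.py is importable from the working directory:

from structformer_as_hf import StructformerConfig, StructformerModelForSequenceClassification

config = StructformerConfig(
    hidden_size=768,
    n_context_layers=0,
    nlayers=12,
    nhead=12,
    ntokens=32000,
    pos_emb=True,
    pad=0,
    n_parser_layers=4,
    conv_size=9,
    relations=("head", "child"),
    weight_act="softmax",
    # extra kwargs are stored on the config and read by the classification head
    num_labels=2,
    classifier_dropout=None,
    hidden_dropout_prob=0.1,
    problem_type="single_label_classification",
)
model = StructformerModelForSequenceClassification(config)  # random weights; no checkpoint is loaded here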
finetune/control_raising_control/checkpoint-400/merges.txt ADDED
The diff for this file is too large to render. See raw diff
 
finetune/control_raising_control/checkpoint-400/optimizer.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7de6edd5fe571d34ec0a653350c58cad024bd88053adb63a25dbf64389a4afff
3
+ size 1069068057
finetune/control_raising_control/checkpoint-400/pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:7ebf4182ca1295e18b36927bca97bdfc348d7fa6004f8e520b06007af1861628
3
+ size 534669003
finetune/control_raising_control/checkpoint-400/rng_state.pth ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:02576a68480d1bd556bf2545dbb91a6fcd2d751faf798e19f80054fd04bf34a7
3
+ size 14503
finetune/control_raising_control/checkpoint-400/scheduler.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:0c2403f14641b3caeb1b4d17bf70ec776358494ec9059cbe53a4c9c5a18c4c15
3
+ size 623
finetune/control_raising_control/checkpoint-400/special_tokens_map.json ADDED
@@ -0,0 +1,15 @@
1
+ {
2
+ "bos_token": "<s>",
3
+ "cls_token": "<s>",
4
+ "eos_token": "</s>",
5
+ "mask_token": {
6
+ "content": "<mask>",
7
+ "lstrip": true,
8
+ "normalized": false,
9
+ "rstrip": false,
10
+ "single_word": false
11
+ },
12
+ "pad_token": "<pad>",
13
+ "sep_token": "</s>",
14
+ "unk_token": "<unk>"
15
+ }
finetune/control_raising_control/checkpoint-400/structformer_as_hf.py ADDED
@@ -0,0 +1,1123 @@
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from torch.nn import init
5
+ from transformers import PreTrainedModel
6
+ from transformers import PretrainedConfig
7
+ from transformers.modeling_outputs import MaskedLMOutput
8
+ from typing import List
9
+ from torch.nn import CrossEntropyLoss, MSELoss, BCEWithLogitsLoss
10
+ from transformers.modeling_outputs import (
11
+ BaseModelOutputWithPastAndCrossAttentions,
12
+ BaseModelOutputWithPoolingAndCrossAttentions,
13
+ MaskedLMOutput,
14
+ SequenceClassifierOutput
15
+ )
16
+
17
+ ##########################################
18
+ # HuggingFace Config
19
+ ##########################################
20
+ class StructformerConfig(PretrainedConfig):
21
+ model_type = "structformer"
22
+
23
+ def __init__(
24
+ self,
25
+ hidden_size=768,
26
+ n_context_layers=2,
27
+ nlayers=6,
28
+ ntokens=32000,
29
+ nhead=8,
30
+ dropout=0.1,
31
+ dropatt=0.1,
32
+ relative_bias=False,
33
+ pos_emb=False,
34
+ pad=0,
35
+ n_parser_layers=4,
36
+ conv_size=9,
37
+ relations=('head', 'child'),
38
+ weight_act='softmax',
39
+ **kwargs,
40
+ ):
41
+ self.hidden_size = hidden_size
42
+ self.n_context_layers = n_context_layers
43
+ self.nlayers = nlayers
44
+ self.ntokens = ntokens
45
+ self.nhead = nhead
46
+ self.dropout = dropout
47
+ self.dropatt = dropatt
48
+ self.relative_bias = relative_bias
49
+ self.pos_emb = pos_emb
50
+ self.pad = pad
51
+ self.n_parser_layers = n_parser_layers
52
+ self.conv_size = conv_size
53
+ self.relations = relations
54
+ self.weight_act = weight_act
55
+ super().__init__(**kwargs)
56
+
57
+ ##########################################
58
+ # Custom Layers
59
+ ##########################################
60
+ def _get_activation_fn(activation):
61
+ """Get specified activation function."""
62
+ if activation == "relu":
63
+ return nn.ReLU()
64
+ elif activation == "gelu":
65
+ return nn.GELU()
66
+ elif activation == "leakyrelu":
67
+ return nn.LeakyReLU()
68
+
69
+ raise RuntimeError(
70
+ "activation should be relu/gelu, not {}".format(activation))
71
+
72
+ class Conv1d(nn.Module):
73
+ """1D convolution layer."""
74
+
75
+ def __init__(self, hidden_size, kernel_size, dilation=1):
76
+ """Initialization.
77
+ Args:
78
+ hidden_size: dimension of input embeddings
79
+ kernel_size: convolution kernel size
80
+ dilation: the spacing between the kernel points
81
+ """
82
+ super(Conv1d, self).__init__()
83
+
84
+ if kernel_size % 2 == 0:
85
+ padding = (kernel_size // 2) * dilation
86
+ self.shift = True
87
+ else:
88
+ padding = ((kernel_size - 1) // 2) * dilation
89
+ self.shift = False
90
+ self.conv = nn.Conv1d(
91
+ hidden_size,
92
+ hidden_size,
93
+ kernel_size,
94
+ padding=padding,
95
+ dilation=dilation)
96
+
97
+ def forward(self, x):
98
+ """Compute convolution.
99
+ Args:
100
+ x: input embeddings
101
+ Returns:
102
+ conv_output: convolution results
103
+ """
104
+
105
+ if self.shift:
106
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)[:, 1:]
107
+ else:
108
+ return self.conv(x.transpose(1, 2)).transpose(1, 2)
109
+
110
+ class MultiheadAttention(nn.Module):
111
+ """Multi-head self-attention layer."""
112
+
113
+ def __init__(self,
114
+ embed_dim,
115
+ num_heads,
116
+ dropout=0.,
117
+ bias=True,
118
+ v_proj=True,
119
+ out_proj=True,
120
+ relative_bias=True):
121
+ """Initialization.
122
+ Args:
123
+ embed_dim: dimension of input embeddings
124
+ num_heads: number of self-attention heads
125
+ dropout: dropout rate
126
+ bias: bool, indicate whether include bias for linear transformations
127
+ v_proj: bool, indicate whether project inputs to new values
128
+ out_proj: bool, indicate whether project outputs to new values
129
+ relative_bias: bool, indicate whether use a relative position based
130
+ attention bias
131
+ """
132
+
133
+ super(MultiheadAttention, self).__init__()
134
+ self.embed_dim = embed_dim
135
+
136
+ self.num_heads = num_heads
137
+ self.drop = nn.Dropout(dropout)
138
+ self.head_dim = embed_dim // num_heads
139
+ assert self.head_dim * num_heads == self.embed_dim, ("embed_dim must be "
140
+ "divisible by "
141
+ "num_heads")
142
+
143
+ self.q_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
144
+ self.k_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
145
+ if v_proj:
146
+ self.v_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
147
+ else:
148
+ self.v_proj = nn.Identity()
149
+
150
+ if out_proj:
151
+ self.out_proj = nn.Linear(embed_dim, embed_dim, bias=bias)
152
+ else:
153
+ self.out_proj = nn.Identity()
154
+
155
+ if relative_bias:
156
+ self.relative_bias = nn.Parameter(torch.zeros((self.num_heads, 512)))
157
+ else:
158
+ self.relative_bias = None
159
+
160
+ self._reset_parameters()
161
+
162
+ def _reset_parameters(self):
163
+ """Initialize attention parameters."""
164
+
165
+ init.xavier_uniform_(self.q_proj.weight)
166
+ init.constant_(self.q_proj.bias, 0.)
167
+
168
+ init.xavier_uniform_(self.k_proj.weight)
169
+ init.constant_(self.k_proj.bias, 0.)
170
+
171
+ if isinstance(self.v_proj, nn.Linear):
172
+ init.xavier_uniform_(self.v_proj.weight)
173
+ init.constant_(self.v_proj.bias, 0.)
174
+
175
+ if isinstance(self.out_proj, nn.Linear):
176
+ init.xavier_uniform_(self.out_proj.weight)
177
+ init.constant_(self.out_proj.bias, 0.)
178
+
179
+ def forward(self, query, key_padding_mask=None, attn_mask=None):
180
+ """Compute multi-head self-attention.
181
+ Args:
182
+ query: input embeddings
183
+ key_padding_mask: 3D mask that prevents attention to certain positions
184
+ attn_mask: 3D mask that rescales the attention weight at each position
185
+ Returns:
186
+ attn_output: self-attention output
187
+ """
188
+
189
+ length, bsz, embed_dim = query.size()
190
+ assert embed_dim == self.embed_dim
191
+
192
+ head_dim = embed_dim // self.num_heads
193
+ assert head_dim * self.num_heads == embed_dim, ("embed_dim must be "
194
+ "divisible by num_heads")
195
+ scaling = float(head_dim)**-0.5
196
+
197
+ q = self.q_proj(query)
198
+ k = self.k_proj(query)
199
+ v = self.v_proj(query)
200
+
201
+ q = q * scaling
202
+
203
+ if attn_mask is not None:
204
+ assert list(attn_mask.size()) == [bsz * self.num_heads,
205
+ query.size(0), query.size(0)]
206
+
207
+ q = q.contiguous().view(length, bsz * self.num_heads,
208
+ head_dim).transpose(0, 1)
209
+ k = k.contiguous().view(length, bsz * self.num_heads,
210
+ head_dim).transpose(0, 1)
211
+ v = v.contiguous().view(length, bsz * self.num_heads,
212
+ head_dim).transpose(0, 1)
213
+
214
+ attn_output_weights = torch.bmm(q, k.transpose(1, 2))
215
+ assert list(
216
+ attn_output_weights.size()) == [bsz * self.num_heads, length, length]
217
+
218
+ if self.relative_bias is not None:
219
+ pos = torch.arange(length, device=query.device)
220
+ relative_pos = torch.abs(pos[:, None] - pos[None, :]) + 256
221
+ relative_pos = relative_pos[None, :, :].expand(bsz * self.num_heads, -1,
222
+ -1)
223
+
224
+ relative_bias = self.relative_bias.repeat_interleave(bsz, dim=0)
225
+ relative_bias = relative_bias[:, None, :].expand(-1, length, -1)
226
+ relative_bias = torch.gather(relative_bias, 2, relative_pos)
227
+ attn_output_weights = attn_output_weights + relative_bias
228
+
229
+ if key_padding_mask is not None:
230
+ attn_output_weights = attn_output_weights + key_padding_mask
231
+
232
+ if attn_mask is None:
233
+ attn_output_weights = torch.softmax(attn_output_weights, dim=-1)
234
+ else:
235
+ attn_output_weights = torch.sigmoid(attn_output_weights) * attn_mask
236
+
237
+ attn_output_weights = self.drop(attn_output_weights)
238
+
239
+ attn_output = torch.bmm(attn_output_weights, v)
240
+
241
+ assert list(attn_output.size()) == [bsz * self.num_heads, length, head_dim]
242
+ attn_output = attn_output.transpose(0, 1).contiguous().view(
243
+ length, bsz, embed_dim)
244
+ attn_output = self.out_proj(attn_output)
245
+
246
+ return attn_output
247
+
248
+ class TransformerLayer(nn.Module):
249
+ """TransformerEncoderLayer is made up of self-attn and feedforward network."""
250
+
251
+ def __init__(self,
252
+ d_model,
253
+ nhead,
254
+ dim_feedforward=2048,
255
+ dropout=0.1,
256
+ dropatt=0.1,
257
+ activation="leakyrelu",
258
+ relative_bias=True):
259
+ """Initialization.
260
+ Args:
261
+ d_model: dimension of inputs
262
+ nhead: number of self-attention heads
263
+ dim_feedforward: dimension of hidden layer in feedforward layer
264
+ dropout: dropout rate
265
+ dropatt: drop attention rate
266
+ activation: activation function
267
+ relative_bias: bool, indicate whether use a relative position based
268
+ attention bias
269
+ """
270
+
271
+ super(TransformerLayer, self).__init__()
272
+
273
+ self.self_attn = MultiheadAttention(
274
+ d_model, nhead, dropout=dropatt, relative_bias=relative_bias)
275
+
276
+ # Implementation of Feedforward model
277
+ self.feedforward = nn.Sequential(
278
+ nn.LayerNorm(d_model), nn.Linear(d_model, dim_feedforward),
279
+ _get_activation_fn(activation), nn.Dropout(dropout),
280
+ nn.Linear(dim_feedforward, d_model))
281
+
282
+ self.norm = nn.LayerNorm(d_model)
283
+ self.dropout1 = nn.Dropout(dropout)
284
+ self.dropout2 = nn.Dropout(dropout)
285
+
286
+ self.nhead = nhead
287
+
288
+ def forward(self, src, attn_mask=None, key_padding_mask=None):
289
+ """Pass the input through the encoder layer.
290
+ Args:
291
+ src: the sequence to the encoder layer (required).
292
+ attn_mask: the mask for the src sequence (optional).
293
+ key_padding_mask: the mask for the src keys per batch (optional).
294
+ Returns:
295
+ src3: the output of the transformer layer, sharing the same shape as src.
296
+ """
297
+ src2 = self.self_attn(
298
+ self.norm(src), attn_mask=attn_mask, key_padding_mask=key_padding_mask)
299
+ src2 = src + self.dropout1(src2)
300
+ src3 = self.feedforward(src2)
301
+ src3 = src2 + self.dropout2(src3)
302
+
303
+ return src3
304
+
305
+
306
+
307
+ class RobertaClassificationHead(nn.Module):
308
+ """Head for sentence-level classification tasks."""
309
+
310
+ def __init__(self, config):
311
+ super().__init__()
312
+ self.dense = nn.Linear(config.hidden_size, config.hidden_size)
313
+ classifier_dropout = (
314
+ config.classifier_dropout if config.classifier_dropout is not None else config.hidden_dropout_prob
315
+ )
316
+ self.dropout = nn.Dropout(classifier_dropout)
317
+ self.out_proj = nn.Linear(config.hidden_size, config.num_labels)
318
+
319
+ def forward(self, features, **kwargs):
320
+ x = features[:, 0, :] # take <s> token (equiv. to [CLS])
321
+ x = self.dropout(x)
322
+ x = self.dense(x)
323
+ x = torch.tanh(x)
324
+ x = self.dropout(x)
325
+ x = self.out_proj(x)
326
+ return x
327
+
328
+
329
+ ##########################################
330
+ # Custom Models
331
+ ##########################################
332
+ def cumprod(x, reverse=False, exclusive=False):
333
+ """cumulative product."""
334
+ if reverse:
335
+ x = x.flip([-1])
336
+
337
+ if exclusive:
338
+ x = F.pad(x[:, :, :-1], (1, 0), value=1)
339
+
340
+ cx = x.cumprod(-1)
341
+
342
+ if reverse:
343
+ cx = cx.flip([-1])
344
+ return cx
345
+
346
+ def cumsum(x, reverse=False, exclusive=False):
347
+ """cumulative sum."""
348
+ bsz, _, length = x.size()
349
+ device = x.device
350
+ if reverse:
351
+ if exclusive:
352
+ w = torch.ones([bsz, length, length], device=device).tril(-1)
353
+ else:
354
+ w = torch.ones([bsz, length, length], device=device).tril(0)
355
+ cx = torch.bmm(x, w)
356
+ else:
357
+ if exclusive:
358
+ w = torch.ones([bsz, length, length], device=device).triu(1)
359
+ else:
360
+ w = torch.ones([bsz, length, length], device=device).triu(0)
361
+ cx = torch.bmm(x, w)
362
+ return cx
363
+
364
+ def cummin(x, reverse=False, exclusive=False, max_value=1e9):
365
+ """cumulative min."""
366
+ if reverse:
367
+ if exclusive:
368
+ x = F.pad(x[:, :, 1:], (0, 1), value=max_value)
369
+ x = x.flip([-1]).cummin(-1)[0].flip([-1])
370
+ else:
371
+ if exclusive:
372
+ x = F.pad(x[:, :, :-1], (1, 0), value=max_value)
373
+ x = x.cummin(-1)[0]
374
+ return x
375
+
376
+ class Transformer(nn.Module):
377
+ """Transformer model."""
378
+
379
+ def __init__(self,
380
+ hidden_size,
381
+ nlayers,
382
+ ntokens,
383
+ nhead=8,
384
+ dropout=0.1,
385
+ dropatt=0.1,
386
+ relative_bias=True,
387
+ pos_emb=False,
388
+ pad=0):
389
+ """Initialization.
390
+ Args:
391
+ hidden_size: dimension of inputs and hidden states
392
+ nlayers: number of layers
393
+ ntokens: number of output categories
394
+ nhead: number of self-attention heads
395
+ dropout: dropout rate
396
+ dropatt: drop attention rate
397
+ relative_bias: bool, indicate whether use a relative position based
398
+ attention bias
399
+ pos_emb: bool, indicate whether use a learnable positional embedding
400
+ pad: pad token index
401
+ """
402
+
403
+ super(Transformer, self).__init__()
404
+
405
+ self.drop = nn.Dropout(dropout)
406
+
407
+ self.emb = nn.Embedding(ntokens, hidden_size)
408
+ if pos_emb:
409
+ self.pos_emb = nn.Embedding(500, hidden_size)
410
+
411
+ self.layers = nn.ModuleList([
412
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
413
+ dropatt=dropatt, relative_bias=relative_bias)
414
+ for _ in range(nlayers)])
415
+
416
+ self.norm = nn.LayerNorm(hidden_size)
417
+
418
+ self.output_layer = nn.Linear(hidden_size, ntokens)
419
+ self.output_layer.weight = self.emb.weight
420
+
421
+ self.init_weights()
422
+
423
+ self.nlayers = nlayers
424
+ self.nhead = nhead
425
+ self.ntokens = ntokens
426
+ self.hidden_size = hidden_size
427
+ self.pad = pad
428
+
429
+ def init_weights(self):
430
+ """Initialize token embedding and output bias."""
431
+ initrange = 0.1
432
+ self.emb.weight.data.uniform_(-initrange, initrange)
433
+ if hasattr(self, 'pos_emb'):
434
+ self.pos_emb.weight.data.uniform_(-initrange, initrange)
435
+ self.output_layer.bias.data.fill_(0)
436
+
437
+ def visibility(self, x, device):
438
+ """Mask pad tokens."""
439
+ visibility = (x != self.pad).float()
440
+ visibility = visibility[:, None, :].expand(-1, x.size(1), -1)
441
+ visibility = torch.repeat_interleave(visibility, self.nhead, dim=0)
442
+ return visibility.log()
443
+
444
+ def encode(self, x, pos):
445
+ """Standard transformer encode process."""
446
+ h = self.emb(x)
447
+ if hasattr(self, 'pos_emb'):
448
+ h = h + self.pos_emb(pos)
449
+ h_list = []
450
+ visibility = self.visibility(x, x.device)
451
+
452
+ for i in range(self.nlayers):
453
+ h_list.append(h)
454
+ h = self.layers[i](
455
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
456
+
457
+ output = h
458
+ h_array = torch.stack(h_list, dim=2)
459
+
460
+ return output, h_array
461
+
462
+ def forward(self, x, pos):
463
+ """Pass the input through the encoder layer.
464
+ Args:
465
+ x: input tokens (required).
466
+ pos: position for each token (optional).
467
+ Returns:
468
+ output: probability distributions for missing tokens.
469
+ state_dict: parsing results and raw output
470
+ """
471
+
472
+ batch_size, length = x.size()
473
+
474
+ raw_output, _ = self.encode(x, pos)
475
+ raw_output = self.norm(raw_output)
476
+ raw_output = self.drop(raw_output)
477
+
478
+ output = self.output_layer(raw_output)
479
+ return output.view(batch_size * length, -1), {'raw_output': raw_output,}
480
+
481
+ class StructFormer(Transformer):
482
+ """StructFormer model."""
483
+
484
+ def __init__(self,
485
+ hidden_size,
486
+ n_context_layers,
487
+ nlayers,
488
+ ntokens,
489
+ nhead=8,
490
+ dropout=0.1,
491
+ dropatt=0.1,
492
+ relative_bias=False,
493
+ pos_emb=False,
494
+ pad=0,
495
+ n_parser_layers=4,
496
+ conv_size=9,
497
+ relations=('head', 'child'),
498
+ weight_act='softmax'):
499
+ """Initialization.
500
+ Args:
501
+ hidden_size: dimension of inputs and hidden states
502
+ nlayers: number of layers
503
+ ntokens: number of output categories
504
+ nhead: number of self-attention heads
505
+ dropout: dropout rate
506
+ dropatt: drop attention rate
507
+ relative_bias: bool, indicate whether use a relative position based
508
+ attention bias
509
+ pos_emb: bool, indicate whether use a learnable positional embedding
510
+ pad: pad token index
511
+ n_parser_layers: number of parsing layers
512
+ conv_size: convolution kernel size for parser
513
+ relations: relations that are used to compute self attention
514
+ weight_act: relations distribution activation function
515
+ """
516
+
517
+ super(StructFormer, self).__init__(
518
+ hidden_size,
519
+ nlayers,
520
+ ntokens,
521
+ nhead=nhead,
522
+ dropout=dropout,
523
+ dropatt=dropatt,
524
+ relative_bias=relative_bias,
525
+ pos_emb=pos_emb,
526
+ pad=pad)
527
+
528
+ if n_context_layers > 0:
529
+ self.context_layers = nn.ModuleList([
530
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
531
+ dropatt=dropatt, relative_bias=relative_bias)
532
+ for _ in range(n_context_layers)])
533
+
534
+ self.parser_layers = nn.ModuleList([
535
+ nn.Sequential(Conv1d(hidden_size, conv_size),
536
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
537
+ nn.Tanh()) for i in range(n_parser_layers)])
538
+
539
+ self.distance_ff = nn.Sequential(
540
+ Conv1d(hidden_size, 2),
541
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
542
+ nn.Linear(hidden_size, 1))
543
+
544
+ self.height_ff = nn.Sequential(
545
+ nn.Linear(hidden_size, hidden_size),
546
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
547
+ nn.Linear(hidden_size, 1))
548
+
549
+ n_rel = len(relations)
550
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
551
+ self._rel_weight.data.normal_(0, 0.1)
552
+
553
+ self._scaler = nn.Parameter(torch.zeros(2))
554
+
555
+ self.n_parse_layers = n_parser_layers
556
+ self.n_context_layers = n_context_layers
557
+ self.weight_act = weight_act
558
+ self.relations = relations
559
+
560
+ @property
561
+ def scaler(self):
562
+ return self._scaler.exp()
563
+
564
+ @property
565
+ def rel_weight(self):
566
+ if self.weight_act == 'sigmoid':
567
+ return torch.sigmoid(self._rel_weight)
568
+ elif self.weight_act == 'softmax':
569
+ return torch.softmax(self._rel_weight, dim=-1)
570
+
571
+ def parse(self, x, pos, embeds=None):
572
+ """Parse input sentence.
573
+ Args:
574
+ x: input tokens (required).
575
+ pos: position for each token (optional).
576
+ Returns:
577
+ distance: syntactic distance
578
+ height: syntactic height
579
+ """
580
+
581
+ mask = (x != self.pad)
582
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
583
+
584
+
585
+ if embeds is not None:
586
+ h = embeds
587
+ else:
588
+ h = self.emb(x)
589
+
590
+ for i in range(self.n_parse_layers):
591
+ h = h.masked_fill(~mask[:, :, None], 0)
592
+ h = self.parser_layers[i](h)
593
+
594
+ height = self.height_ff(h).squeeze(-1)
595
+ height.masked_fill_(~mask, -1e9)
596
+
597
+ distance = self.distance_ff(h).squeeze(-1)
598
+ distance.masked_fill_(~mask_shifted, 1e9)
599
+
600
+ # Calibrating the distance and height to the same level
601
+ length = distance.size(1)
602
+ height_max = height[:, None, :].expand(-1, length, -1)
603
+ height_max = torch.cummax(
604
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
605
+ dim=-1)[0].triu(0)
606
+
607
+ margin_left = torch.relu(
608
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
609
+ margin_right = torch.relu(distance[:, None, :] - height_max)
610
+ margin = torch.where(margin_left > margin_right, margin_right,
611
+ margin_left).triu(0)
612
+
613
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
614
+ margin.masked_fill_(~margin_mask, 0)
615
+ margin = margin.max()
616
+
617
+ distance = distance - margin
618
+
619
+ return distance, height
620
+
621
+ def compute_block(self, distance, height):
622
+ """Compute constituents from distance and height."""
623
+
624
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
625
+
626
+ gamma = torch.sigmoid(-beta_logits)
627
+ ones = torch.ones_like(gamma)
628
+
629
+ block_mask_left = cummin(
630
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
631
+ block_mask_left = block_mask_left - F.pad(
632
+ block_mask_left[:, :, :-1], (1, 0), value=0)
633
+ block_mask_left.tril_(0)
634
+
635
+ block_mask_right = cummin(
636
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
637
+ block_mask_right = block_mask_right - F.pad(
638
+ block_mask_right[:, :, 1:], (0, 1), value=0)
639
+ block_mask_right.triu_(0)
640
+
641
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
642
+ block = cumsum(block_mask_left).tril(0) + cumsum(
643
+ block_mask_right, reverse=True).triu(1)
644
+
645
+ return block_p, block
646
+
647
+ def compute_head(self, height):
648
+ """Estimate head for each constituent."""
649
+
650
+ _, length = height.size()
651
+ head_logits = height * self.scaler[1]
652
+ index = torch.arange(length, device=height.device)
653
+
654
+ mask = (index[:, None, None] <= index[None, None, :]) * (
655
+ index[None, None, :] <= index[None, :, None])
656
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
657
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
658
+
659
+ head_p = torch.softmax(head_logits, dim=-1)
660
+
661
+ return head_p
662
+
663
+ def generate_mask(self, x, distance, height):
664
+ """Compute head and cibling distribution for each token."""
665
+
666
+ bsz, length = x.size()
667
+
668
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
669
+ eye = eye[None, :, :].expand((bsz, -1, -1))
670
+
671
+ block_p, block = self.compute_block(distance, height)
672
+ head_p = self.compute_head(height)
673
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
674
+ head = head.masked_fill(eye, 0)
675
+ child = head.transpose(1, 2)
676
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
677
+
678
+ rel_list = []
679
+ if 'head' in self.relations:
680
+ rel_list.append(head)
681
+ if 'child' in self.relations:
682
+ rel_list.append(child)
683
+ if 'cibling' in self.relations:
684
+ rel_list.append(cibling)
685
+
686
+ rel = torch.stack(rel_list, dim=1)
687
+
688
+ rel_weight = self.rel_weight
689
+
690
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
691
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
692
+
693
+ return att_mask, cibling, head, block
694
+
695
+ def encode(self, x, pos, att_mask=None, context_layers=False):
696
+ """Structformer encoding process."""
697
+
698
+ if context_layers:
699
+ """Standard transformer encode process."""
700
+ h = self.emb(x)
701
+ if hasattr(self, 'pos_emb'):
702
+ h = h + self.pos_emb(pos)
703
+ h_list = []
704
+ visibility = self.visibility(x, x.device)
705
+ for i in range(self.n_context_layers):
706
+ h_list.append(h)
707
+ h = self.context_layers[i](
708
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
709
+
710
+ output = h
711
+ h_array = torch.stack(h_list, dim=2)
712
+ return output
713
+
714
+ else:
715
+ visibility = self.visibility(x, x.device)
716
+ h = self.emb(x)
717
+ if hasattr(self, 'pos_emb'):
718
+ assert pos.max() < 500
719
+ h = h + self.pos_emb(pos)
720
+ for i in range(self.nlayers):
721
+ h = self.layers[i](
722
+ h.transpose(0, 1), attn_mask=att_mask[i],
723
+ key_padding_mask=visibility).transpose(0, 1)
724
+ return h
725
+
726
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
727
+
728
+ x = input_ids
729
+ batch_size, length = x.size()
730
+
731
+ if position_ids is None:
732
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
733
+
734
+ context_layers_output = None
735
+ if self.n_context_layers > 0:
736
+ context_layers_output = self.encode(x, pos, context_layers=True)
737
+
738
+ distance, height = self.parse(x, pos, embeds=context_layers_output)
739
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
740
+
741
+ raw_output = self.encode(x, pos, att_mask)
742
+ raw_output = self.norm(raw_output)
743
+ raw_output = self.drop(raw_output)
744
+
745
+ output = self.output_layer(raw_output)
746
+
747
+ loss = None
748
+ if labels is not None:
749
+ loss_fct = nn.CrossEntropyLoss()
750
+ loss = loss_fct(output.view(batch_size * length, -1), labels.reshape(-1))
751
+
752
+ return MaskedLMOutput(
753
+ loss=loss, # shape: 1
754
+ logits=output, # shape: (batch_size, length, ntokens)
755
+ hidden_states=None,
756
+ attentions=None,
757
+ )
758
+
759
+
760
+
761
+
762
+ class StructFormerClassification(Transformer):
763
+ """StructFormer model."""
764
+
765
+ def __init__(self,
766
+ hidden_size,
767
+ n_context_layers,
768
+ nlayers,
769
+ ntokens,
770
+ nhead=8,
771
+ dropout=0.1,
772
+ dropatt=0.1,
773
+ relative_bias=False,
774
+ pos_emb=False,
775
+ pad=0,
776
+ n_parser_layers=4,
777
+ conv_size=9,
778
+ relations=('head', 'child'),
779
+ weight_act='softmax',
780
+ config=None,
781
+ ):
782
+
783
+
784
+ super(StructFormerClassification, self).__init__(
785
+ hidden_size,
786
+ nlayers,
787
+ ntokens,
788
+ nhead=nhead,
789
+ dropout=dropout,
790
+ dropatt=dropatt,
791
+ relative_bias=relative_bias,
792
+ pos_emb=pos_emb,
793
+ pad=pad)
794
+
795
+ self.num_labels = config.num_labels
796
+ self.config = config
797
+
798
+ if n_context_layers > 0:
799
+ self.context_layers = nn.ModuleList([
800
+ TransformerLayer(hidden_size, nhead, hidden_size * 4, dropout,
801
+ dropatt=dropatt, relative_bias=relative_bias)
802
+ for _ in range(n_context_layers)])
803
+
804
+ self.parser_layers = nn.ModuleList([
805
+ nn.Sequential(Conv1d(hidden_size, conv_size),
806
+ nn.LayerNorm(hidden_size, elementwise_affine=False),
807
+ nn.Tanh()) for i in range(n_parser_layers)])
808
+
809
+ self.distance_ff = nn.Sequential(
810
+ Conv1d(hidden_size, 2),
811
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
812
+ nn.Linear(hidden_size, 1))
813
+
814
+ self.height_ff = nn.Sequential(
815
+ nn.Linear(hidden_size, hidden_size),
816
+ nn.LayerNorm(hidden_size, elementwise_affine=False), nn.Tanh(),
817
+ nn.Linear(hidden_size, 1))
818
+
819
+ n_rel = len(relations)
820
+ self._rel_weight = nn.Parameter(torch.zeros((nlayers, nhead, n_rel)))
821
+ self._rel_weight.data.normal_(0, 0.1)
822
+
823
+ self._scaler = nn.Parameter(torch.zeros(2))
824
+
825
+ self.n_parse_layers = n_parser_layers
826
+ self.n_context_layers = n_context_layers
827
+ self.weight_act = weight_act
828
+ self.relations = relations
829
+
830
+ self.classifier = RobertaClassificationHead(config)
831
+
832
+ @property
833
+ def scaler(self):
834
+ return self._scaler.exp()
835
+
836
+ @property
837
+ def rel_weight(self):
838
+ if self.weight_act == 'sigmoid':
839
+ return torch.sigmoid(self._rel_weight)
840
+ elif self.weight_act == 'softmax':
841
+ return torch.softmax(self._rel_weight, dim=-1)
842
+
843
+ def parse(self, x, pos, embeds=None):
844
+ """Parse input sentence.
845
+ Args:
846
+ x: input tokens (required).
847
+ pos: position for each token (optional).
848
+ Returns:
849
+ distance: syntactic distance
850
+ height: syntactic height
851
+ """
852
+
853
+ mask = (x != self.pad)
854
+ mask_shifted = F.pad(mask[:, 1:], (0, 1), value=0)
855
+
856
+
857
+ if embeds is not None:
858
+ h = embeds
859
+ else:
860
+ h = self.emb(x)
861
+
862
+ for i in range(self.n_parse_layers):
863
+ h = h.masked_fill(~mask[:, :, None], 0)
864
+ h = self.parser_layers[i](h)
865
+
866
+ height = self.height_ff(h).squeeze(-1)
867
+ height.masked_fill_(~mask, -1e9)
868
+
869
+ distance = self.distance_ff(h).squeeze(-1)
870
+ distance.masked_fill_(~mask_shifted, 1e9)
871
+
872
+ # Calibrating the distance and height to the same level
873
+ length = distance.size(1)
874
+ height_max = height[:, None, :].expand(-1, length, -1)
875
+ height_max = torch.cummax(
876
+ height_max.triu(0) - torch.ones_like(height_max).tril(-1) * 1e9,
877
+ dim=-1)[0].triu(0)
878
+
879
+ margin_left = torch.relu(
880
+ F.pad(distance[:, :-1, None], (0, 0, 1, 0), value=1e9) - height_max)
881
+ margin_right = torch.relu(distance[:, None, :] - height_max)
882
+ margin = torch.where(margin_left > margin_right, margin_right,
883
+ margin_left).triu(0)
884
+
885
+ margin_mask = torch.stack([mask_shifted] + [mask] * (length - 1), dim=1)
886
+ margin.masked_fill_(~margin_mask, 0)
887
+ margin = margin.max()
888
+
889
+ distance = distance - margin
890
+
891
+ return distance, height
892
+
893
+ def compute_block(self, distance, height):
894
+ """Compute constituents from distance and height."""
895
+
896
+ beta_logits = (distance[:, None, :] - height[:, :, None]) * self.scaler[0]
897
+
898
+ gamma = torch.sigmoid(-beta_logits)
899
+ ones = torch.ones_like(gamma)
900
+
901
+ block_mask_left = cummin(
902
+ gamma.tril(-1) + ones.triu(0), reverse=True, max_value=1)
903
+ block_mask_left = block_mask_left - F.pad(
904
+ block_mask_left[:, :, :-1], (1, 0), value=0)
905
+ block_mask_left.tril_(0)
906
+
907
+ block_mask_right = cummin(
908
+ gamma.triu(0) + ones.tril(-1), exclusive=True, max_value=1)
909
+ block_mask_right = block_mask_right - F.pad(
910
+ block_mask_right[:, :, 1:], (0, 1), value=0)
911
+ block_mask_right.triu_(0)
912
+
913
+ block_p = block_mask_left[:, :, :, None] * block_mask_right[:, :, None, :]
914
+ block = cumsum(block_mask_left).tril(0) + cumsum(
915
+ block_mask_right, reverse=True).triu(1)
916
+
917
+ return block_p, block
918
+
919
+ def compute_head(self, height):
920
+ """Estimate head for each constituent."""
921
+
922
+ _, length = height.size()
923
+ head_logits = height * self.scaler[1]
924
+ index = torch.arange(length, device=height.device)
925
+
926
+ mask = (index[:, None, None] <= index[None, None, :]) * (
927
+ index[None, None, :] <= index[None, :, None])
928
+ head_logits = head_logits[:, None, None, :].repeat(1, length, length, 1)
929
+ head_logits.masked_fill_(~mask[None, :, :, :], -1e9)
930
+
931
+ head_p = torch.softmax(head_logits, dim=-1)
932
+
933
+ return head_p
934
+
935
+ def generate_mask(self, x, distance, height):
936
+ """Compute head and cibling distribution for each token."""
937
+
938
+ bsz, length = x.size()
939
+
940
+ eye = torch.eye(length, device=x.device, dtype=torch.bool)
941
+ eye = eye[None, :, :].expand((bsz, -1, -1))
942
+
943
+ block_p, block = self.compute_block(distance, height)
944
+ head_p = self.compute_head(height)
945
+ head = torch.einsum('blij,bijh->blh', block_p, head_p)
946
+ head = head.masked_fill(eye, 0)
947
+ child = head.transpose(1, 2)
948
+ cibling = torch.bmm(head, child).masked_fill(eye, 0)
949
+
950
+ rel_list = []
951
+ if 'head' in self.relations:
952
+ rel_list.append(head)
953
+ if 'child' in self.relations:
954
+ rel_list.append(child)
955
+ if 'cibling' in self.relations:
956
+ rel_list.append(cibling)
957
+
958
+ rel = torch.stack(rel_list, dim=1)
959
+
960
+ rel_weight = self.rel_weight
961
+
962
+ dep = torch.einsum('lhr,brij->lbhij', rel_weight, rel)
963
+ att_mask = dep.reshape(self.nlayers, bsz * self.nhead, length, length)
964
+
965
+ return att_mask, cibling, head, block
966
+
967
+ def encode(self, x, pos, att_mask=None, context_layers=False):
968
+ """Structformer encoding process."""
969
+
970
+ if context_layers:
971
+ """Standard transformer encode process."""
972
+ h = self.emb(x)
973
+ if hasattr(self, 'pos_emb'):
974
+ h = h + self.pos_emb(pos)
975
+ h_list = []
976
+ visibility = self.visibility(x, x.device)
977
+ for i in range(self.n_context_layers):
978
+ h_list.append(h)
979
+ h = self.context_layers[i](
980
+ h.transpose(0, 1), key_padding_mask=visibility).transpose(0, 1)
981
+
982
+ output = h
983
+ h_array = torch.stack(h_list, dim=2)
984
+ return output
985
+
986
+ else:
987
+ visibility = self.visibility(x, x.device)
988
+ h = self.emb(x)
989
+ if hasattr(self, 'pos_emb'):
990
+ assert pos.max() < 500
991
+ h = h + self.pos_emb(pos)
992
+ for i in range(self.nlayers):
993
+ h = self.layers[i](
994
+ h.transpose(0, 1), attn_mask=att_mask[i],
995
+ key_padding_mask=visibility).transpose(0, 1)
996
+ return h
997
+
998
+ def forward(self, input_ids, labels=None, position_ids=None, **kwargs):
999
+
1000
+ x = input_ids
1001
+ batch_size, length = x.size()
1002
+
1003
+ if position_ids is None:
1004
+ pos = torch.arange(length, device=x.device).expand(batch_size, length)
1005
+
1006
+ context_layers_output = None
1007
+ if self.n_context_layers > 0:
1008
+ context_layers_output = self.encode(x, pos, context_layers=True)
1009
+
1010
+ distance, height = self.parse(x, pos, embeds=context_layers_output)
1011
+ att_mask, cibling, head, block = self.generate_mask(x, distance, height)
1012
+
1013
+ raw_output = self.encode(x, pos, att_mask)
1014
+ raw_output = self.norm(raw_output)
1015
+ raw_output = self.drop(raw_output)
1016
+
1017
+ #output = self.output_layer(raw_output)
1018
+ logits = self.classifier(raw_output)
1019
+
1020
+ loss = None
1021
+ if labels is not None:
1022
+ if self.config.problem_type is None:
1023
+ if self.num_labels == 1:
1024
+ self.config.problem_type = "regression"
1025
+ elif self.num_labels > 1 and (labels.dtype == torch.long or labels.dtype == torch.int):
1026
+ self.config.problem_type = "single_label_classification"
1027
+ else:
1028
+ self.config.problem_type = "multi_label_classification"
1029
+
1030
+ if self.config.problem_type == "regression":
1031
+ loss_fct = MSELoss()
1032
+ if self.num_labels == 1:
1033
+ loss = loss_fct(logits.squeeze(), labels.squeeze())
1034
+ else:
1035
+ loss = loss_fct(logits, labels)
1036
+ elif self.config.problem_type == "single_label_classification":
1037
+ loss_fct = CrossEntropyLoss()
1038
+ loss = loss_fct(logits.view(-1, self.num_labels), labels.view(-1))
1039
+ elif self.config.problem_type == "multi_label_classification":
1040
+ loss_fct = BCEWithLogitsLoss()
1041
+ loss = loss_fct(logits, labels)
1042
+
1043
+
1044
+ return SequenceClassifierOutput(
1045
+ loss=loss,
1046
+ logits=logits,
1047
+ hidden_states=None,
1048
+ attentions=None,
1049
+ )
1050
+
1051
+
1052
+
1053
+ ##########################################
1054
+ # HuggingFace Model
1055
+ ##########################################
1056
+ class StructformerModel(PreTrainedModel):
1057
+ config_class = StructformerConfig
1058
+
1059
+ def __init__(self, config):
1060
+ super().__init__(config)
1061
+ self.model = StructFormer(
1062
+ hidden_size=config.hidden_size,
1063
+ n_context_layers=config.n_context_layers,
1064
+ nlayers=config.nlayers,
1065
+ ntokens=config.ntokens,
1066
+ nhead=config.nhead,
1067
+ dropout=config.dropout,
1068
+ dropatt=config.dropatt,
1069
+ relative_bias=config.relative_bias,
1070
+ pos_emb=config.pos_emb,
1071
+ pad=config.pad,
1072
+ n_parser_layers=config.n_parser_layers,
1073
+ conv_size=config.conv_size,
1074
+ relations=config.relations,
1075
+ weight_act=config.weight_act
1076
+ )
1077
+
1078
+ def forward(self, input_ids, labels=None, **kwargs):
1079
+ return self.model(input_ids, labels=labels, **kwargs)
1080
+
1081
+
1082
+
1083
+ class StructformerModelForSequenceClassification(PreTrainedModel):
1084
+ config_class = StructformerConfig
1085
+ def __init__(self, config):
1086
+ super().__init__(config)
1087
+ self.model = StructFormerClassification(
1088
+ hidden_size=config.hidden_size,
1089
+ n_context_layers=config.n_context_layers,
1090
+ nlayers=config.nlayers,
1091
+ ntokens=config.ntokens,
1092
+ nhead=config.nhead,
1093
+ dropout=config.dropout,
1094
+ dropatt=config.dropatt,
1095
+ relative_bias=config.relative_bias,
1096
+ pos_emb=config.pos_emb,
1097
+ pad=config.pad,
1098
+ n_parser_layers=config.n_parser_layers,
1099
+ conv_size=config.conv_size,
1100
+ relations=config.relations,
1101
+ weight_act=config.weight_act,
1102
+ config=config)
1103
+
1104
+ def _init_weights(self, module):
1105
+ """Initialize the weights"""
1106
+ if isinstance(module, nn.Linear):
1107
+ # Slightly different from the TF version which uses truncated_normal for initialization
1108
+ # cf https://github.com/pytorch/pytorch/pull/5617
1109
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1110
+ if module.bias is not None:
1111
+ module.bias.data.zero_()
1112
+ elif isinstance(module, nn.Embedding):
1113
+ module.weight.data.normal_(mean=0.0, std=self.config.initializer_range)
1114
+ if module.padding_idx is not None:
1115
+ module.weight.data[module.padding_idx].zero_()
1116
+ elif isinstance(module, nn.LayerNorm):
1117
+ if module.bias is not None:
1118
+ module.bias.data.zero_()
1119
+ module.weight.data.fill_(1.0)
1120
+
1121
+
1122
+ def forward(self, input_ids, labels=None, **kwargs):
1123
+ return self.model(input_ids, labels=labels, **kwargs)
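The script above defines the config and model wrappers but ships no loading example. The following is a minimal usage sketch, not part of the commit: it assumes structformer_as_hf.py is importable, that a checkpoint directory contains the config.json, pytorch_model.bin and tokenizer files added in this commit, and that the saved config supplies num_labels, classifier_dropout and hidden_dropout_prob for the classification head; the path is illustrative.

from transformers import AutoConfig, AutoModelForSequenceClassification, AutoTokenizer
from structformer_as_hf import StructformerConfig, StructformerModelForSequenceClassification

# Register the custom architecture so the Auto* factories can resolve model_type "structformer".
AutoConfig.register("structformer", StructformerConfig)
AutoModelForSequenceClassification.register(StructformerConfig, StructformerModelForSequenceClassification)

ckpt = "finetune/control_raising_control/checkpoint-400"  # illustrative local path
tokenizer = AutoTokenizer.from_pretrained(ckpt)
model = AutoModelForSequenceClassification.from_pretrained(ckpt)

# The custom forward only needs input_ids; padding is handled via the pad token index.
inputs = tokenizer("The boy seems to sleep.", return_tensors="pt")
logits = model(input_ids=inputs["input_ids"]).logits  # shape: (batch_size, num_labels)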
finetune/control_raising_control/checkpoint-400/tokenizer_config.json ADDED
@@ -0,0 +1,65 @@
1
+ {
2
+ "add_prefix_space": false,
3
+ "bos_token": {
4
+ "__type": "AddedToken",
5
+ "content": "<s>",
6
+ "lstrip": false,
7
+ "normalized": true,
8
+ "rstrip": false,
9
+ "single_word": false
10
+ },
11
+ "cls_token": {
12
+ "__type": "AddedToken",
13
+ "content": "<s>",
14
+ "lstrip": false,
15
+ "normalized": true,
16
+ "rstrip": false,
17
+ "single_word": false
18
+ },
19
+ "eos_token": {
20
+ "__type": "AddedToken",
21
+ "content": "</s>",
22
+ "lstrip": false,
23
+ "normalized": true,
24
+ "rstrip": false,
25
+ "single_word": false
26
+ },
27
+ "errors": "replace",
28
+ "mask_token": {
29
+ "__type": "AddedToken",
30
+ "content": "<mask>",
31
+ "lstrip": true,
32
+ "normalized": false,
33
+ "rstrip": false,
34
+ "single_word": false
35
+ },
36
+ "model_max_length": 512,
37
+ "name_or_path": "omarmomen/structformer_s1_final_with_pos",
38
+ "pad_token": {
39
+ "__type": "AddedToken",
40
+ "content": "<pad>",
41
+ "lstrip": false,
42
+ "normalized": true,
43
+ "rstrip": false,
44
+ "single_word": false
45
+ },
46
+ "sep_token": {
47
+ "__type": "AddedToken",
48
+ "content": "</s>",
49
+ "lstrip": false,
50
+ "normalized": true,
51
+ "rstrip": false,
52
+ "single_word": false
53
+ },
54
+ "special_tokens_map_file": null,
55
+ "tokenizer_class": "RobertaTokenizer",
56
+ "trim_offsets": true,
57
+ "unk_token": {
58
+ "__type": "AddedToken",
59
+ "content": "<unk>",
60
+ "lstrip": false,
61
+ "normalized": true,
62
+ "rstrip": false,
63
+ "single_word": false
64
+ }
65
+ }
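A quick check of the tokenizer settings above; this is a sketch assuming the checkpoint directory also holds the vocab.json and merges.txt added for this checkpoint, with an illustrative path.

from transformers import RobertaTokenizer

tok = RobertaTokenizer.from_pretrained("finetune/control_raising_control/checkpoint-400")
print(tok.cls_token, tok.sep_token, tok.mask_token, tok.pad_token)  # <s> </s> <mask> <pad>
print(tok.model_max_length)  # 512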
finetune/control_raising_control/checkpoint-400/trainer_state.json ADDED
@@ -0,0 +1,27 @@
1
+ {
2
+ "best_metric": 0.8853014990164824,
3
+ "best_model_checkpoint": "finetune_results/omarmomen/structformer_s1_final_with_pos/control_raising_control/checkpoint-400",
4
+ "epoch": 7.2727272727272725,
5
+ "global_step": 400,
6
+ "is_hyper_param_search": false,
7
+ "is_local_process_zero": true,
8
+ "is_world_process_zero": true,
9
+ "log_history": [
10
+ {
11
+ "epoch": 7.27,
12
+ "eval_accuracy": 0.8736362457275391,
13
+ "eval_f1": 0.8853014990164824,
14
+ "eval_loss": 1.199914813041687,
15
+ "eval_mcc": 0.7674272477853503,
16
+ "eval_runtime": 25.9995,
17
+ "eval_samples_per_second": 514.702,
18
+ "eval_steps_per_second": 64.347,
19
+ "step": 400
20
+ }
21
+ ],
22
+ "max_steps": 550,
23
+ "num_train_epochs": 10,
24
+ "total_flos": 3989207699389440.0,
25
+ "trial_name": null,
26
+ "trial_params": null
27
+ }
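The fractional epoch logged above is consistent with the step counts in the same file: 550 max_steps over 10 epochs is 55 optimizer steps per epoch, so checkpoint step 400 falls at 400 / 55 ≈ 7.27 epochs.

# Values read directly from trainer_state.json above.
max_steps, num_train_epochs, global_step = 550, 10, 400
steps_per_epoch = max_steps / num_train_epochs   # 55.0
print(global_step / steps_per_epoch)             # 7.2727..., matching the "epoch" field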
finetune/control_raising_control/checkpoint-400/training_args.bin ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f1dcbfabab2b2121e41e9e1698127c1906a1da248b0530c01c7064746c08a2fc
3
+ size 3567
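training_args.bin is stored as a Git LFS pointer, so the oid and size above can be used to verify a downloaded copy. A minimal check, assuming the file has been fetched (e.g. via git lfs pull) to the illustrative path below:

import hashlib, os

path = "finetune/control_raising_control/checkpoint-400/training_args.bin"
expected_oid = "f1dcbfabab2b2121e41e9e1698127c1906a1da248b0530c01c7064746c08a2fc"
expected_size = 3567

# Hash the local file and compare against the LFS pointer metadata.
with open(path, "rb") as f:
    digest = hashlib.sha256(f.read()).hexdigest()
assert digest == expected_oid and os.path.getsize(path) == expected_size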