Vivek committed
Commit 9c9f971
1 Parent(s): 94f6551

deleting files

Files changed (2):
  1. GPT2(error).ipynb +0 -1074
  2. Untitled330.ipynb +0 -470
GPT2(error).ipynb DELETED
@@ -1,1074 +0,0 @@
Notebook metadata: nbformat 4, nbformat_minor 0; Colab notebook "GPT2(error).ipynb" (provenance: [], collapsed_sections: []); kernelspec: python3 ("Python 3"); language_info: python.
Widget metadata: the "widgets" block stores serialized ipywidgets state (HBoxModel, FloatProgressModel and HTMLModel entries with their LayoutModel/ProgressStyleModel/DescriptionStyleModel children) for the three "Downloading: 100%" tokenizer progress bars shown under cell [2]; it contains display state only, no code.

Cell [1]:
%%capture
!pip install transformers
!pip install datasets
!pip install --upgrade git+https://github.com/google/flax.git

Cell [2]:
import jax
from transformers.modeling_flax_utils import FlaxPreTrainedModel
import flax.linen as nn
import jax.numpy as jnp
from transformers import GPT2Config
from transformers import FlaxGPT2PreTrainedModel
from transformers import FlaxGPT2Model
from transformers import GPT2Tokenizer
tokenizer = GPT2Tokenizer.from_pretrained("gpt2", pad_token='<|endoftext|>')

Output: three "Downloading: 100%" progress bars (1.04M/1.04M [00:06<00:00, 154kB/s], 456k/456k [00:04<00:00, 96.1kB/s], 1.36M/1.36M [00:00<00:00, 1.73MB/s]).

Cell [3]:
#inputs = tokenizer(["JAX/Flax is amazing ","tensorflow is also good"],["pytorch is better","keras is the best"],return_tensors='jax',padding='max_length',max_length=30)

Cell [7]:
class FlaxGPT2ForMultipleChoiceModule(nn.Module):
    config: GPT2Config
    dtype: jnp.dtype = jnp.float32

    def setup(self):
        self.gpt2 = FlaxGPT2Model(config=self.config, dtype=self.dtype)
        self.dropout = nn.Dropout(rate=0.2)
        self.classifier = nn.Dense(4, dtype=self.dtype)

    def __call__(self, input_ids, attention_mask, position_ids, return_dict=True, deterministic=True, *args):
        batch_size = input_ids.shape[0]

        rng = jax.random.PRNGKey(0)
        _, dropout_rng = jax.random.split(rng)

        outputs = self.gpt2(input_ids, attention_mask, position_ids, return_dict=return_dict)

        hidden_states = outputs[0]
        hidden_states = jnp.mean(hidden_states, axis=1)
        print(hidden_states.shape)

        hidden_states = hidden_states.reshape(batch_size, -1)  # (32,8,768)->(32,8*768)

        dropout_output = self.dropout(hidden_states, deterministic=deterministic, rng=dropout_rng)
        print(dropout_output.shape)

        logits = self.classifier(dropout_output)
        reshaped_logits = logits.reshape(-1, 4)  # (32,4)
        if not return_dict:
            return (reshaped_logits,) + outputs[2:]
        return reshaped_logits

Cell [8]:
class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):
    module_class = FlaxGPT2ForMultipleChoiceModule

Cell [9]:
model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2')  # getting warning

Output (stdout):
(1, 768)
(1, 768)

Output (stderr):
Some weights of the model checkpoint at gpt2 were not used when initializing FlaxGPT2ForMultipleChoice: {('h', '1', 'ln_1', 'bias'), ('h', '6', 'ln_1', 'scale'), ('h', '1', 'attn', 'c_proj', 'kernel'), ('h', '11', 'mlp', 'c_fc', 'bias'), ('h', '7', 'ln_1', 'bias'), ... , ('wte', 'embedding'), ('wpe', 'embedding'), ('ln_f', 'scale'), ('ln_f', 'bias')} (the elided set covers every pretrained GPT-2 parameter: the attention, MLP and layer-norm weights of blocks h.0 through h.11 plus the token and position embeddings and the final layer norm)
- This IS expected if you are initializing FlaxGPT2ForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).
- This IS NOT expected if you are initializing FlaxGPT2ForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).
Some weights of FlaxGPT2ForMultipleChoice were not initialized from the model checkpoint at gpt2 and are newly initialized: {('classifier', 'bias'), ('classifier', 'kernel')}
You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.

Cell [10]:
input_ids = jnp.ones((4, 5, 6))
attention_mask = jnp.ones((4, 5, 6))

Cell [11]:
out1 = model(input_ids, attention_mask)  # GPT2 will not take (batch_size,num_choice,sequence_length)

Output (error):
ValueError                                Traceback (most recent call last)
<ipython-input-11-7491141e6756> in <module>()
----> 1 out1 = model(input_ids, attention_mask) #GPT2 will not take (batch_size,num_choice,sequence_length)

/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py in __call__(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)
    370         return_dict = return_dict if return_dict is not None else self.config.return_dict
    371
--> 372         batch_size, sequence_length = input_ids.shape
    373
    374         if position_ids is None:

ValueError: too many values to unpack (expected 2)

Cell [ ]:
print(out1)

Output:
[[ 1.1391759 -0.01598702 0.55463445 0.36025363]
 [ 0.32208228 0.37667227 0.87823874 0.19541818]
 [ 0.76971424 0.7187787 0.68642044 -0.31461257]
 [ 1.2375658 0.03325981 0.00153449 0.12019679]]

Cell [ ]: (empty)
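Note on the failure above: the comment on the failing cell is accurate for the stock model; the GPT-2 Flax `__call__` in modeling_flax_gpt2.py unpacks `batch_size, sequence_length = input_ids.shape`, so it only accepts 2-D (batch_size, sequence_length) input, and a (batch_size, num_choices, sequence_length) multiple-choice batch raises exactly the ValueError shown. A minimal sketch of the usual workaround, folding the choice axis into the batch axis before the GPT-2 call and regrouping the scores afterwards (hypothetical helper, not part of the deleted notebook):

import jax.numpy as jnp

def flatten_choices(input_ids, attention_mask):
    # (batch, num_choices, seq_len) -> (batch * num_choices, seq_len),
    # the 2-D layout the Flax GPT-2 forward pass expects
    batch_size, num_choices, seq_len = input_ids.shape
    flat_ids = input_ids.reshape(batch_size * num_choices, seq_len)
    flat_mask = attention_mask.reshape(batch_size * num_choices, seq_len)
    return flat_ids, flat_mask, (batch_size, num_choices)

# after the model call, per-choice scores would be regrouped per example:
#   logits.reshape(batch_size, num_choices)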
Untitled330.ipynb DELETED
@@ -1,470 +0,0 @@
1
- {
2
- "nbformat": 4,
3
- "nbformat_minor": 0,
4
- "metadata": {
5
- "colab": {
6
- "name": "Untitled330.ipynb",
7
- "provenance": [],
8
- "collapsed_sections": []
9
- },
10
- "kernelspec": {
11
- "name": "python3",
12
- "display_name": "Python 3"
13
- },
14
- "language_info": {
15
- "name": "python"
16
- }
17
- },
18
- "cells": [
19
- {
20
- "cell_type": "code",
21
- "metadata": {
22
- "id": "Ii2x731Ta8fu"
23
- },
24
- "source": [
25
- "%%capture\n",
26
- "!pip install transformers\n",
27
- "!pip install datasets\n",
28
- "!pip install --upgrade git+https://github.com/google/flax.git"
29
- ],
30
- "execution_count": 1,
31
- "outputs": []
32
- },
33
- {
34
- "cell_type": "code",
35
- "metadata": {
36
- "id": "_9NMPFKua9hr"
37
- },
38
- "source": [
39
- "import jax\n",
40
- "from transformers.modeling_flax_utils import FlaxPreTrainedModel\n",
41
- "import flax.linen as nn\n",
42
- "import jax.numpy as jnp\n",
43
- "from transformers import GPT2Config\n",
44
- "#from transformers import FlaxGPT2PreTrainedModel\n",
45
- "from transformers import FlaxGPT2Model\n",
46
- "import jax.numpy as jnp\n",
47
- "from transformers import GPT2Tokenizer\n",
48
- "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\",pad_token='<|endoftext|>') \n",
49
- "from typing import Any, Optional, Tuple\n",
50
- "from flax.core.frozen_dict import FrozenDict, unfreeze\n",
51
- "from transformers import file_utils\n",
52
- "from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward\n",
53
- "from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2BlockCollection\n",
54
- "from transformers.modeling_flax_outputs import FlaxBaseModelOutput"
55
- ],
56
- "execution_count": null,
57
- "outputs": []
58
- },
59
- {
60
- "cell_type": "code",
61
- "metadata": {
62
- "id": "dqkcoBOccszd"
63
- },
64
- "source": [
65
- "GPT2_START_DOCSTRING = r\"\"\"\n",
66
- " This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the\n",
67
- " generic methods the library implements for all its model (such as downloading or saving, resizing the input\n",
68
- " embeddings, pruning heads etc.)\n",
69
- " This model is also a Flax Linen `flax.nn.Module\n",
70
- " <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax\n",
71
- " Module and refer to the Flax documentation for all matter related to general usage and behavior.\n",
72
- " Finally, this model supports inherent JAX features such as:\n",
73
- " - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__\n",
74
- " - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__\n",
75
- " - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__\n",
76
- " - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__\n",
77
- " Parameters:\n",
78
- " config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.\n",
79
- " Initializing with a config file does not load the weights associated with the model, only the\n",
80
- " configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the\n",
81
- " model weights.\n",
82
- "\"\"\"\n",
83
- "\n",
84
- "GPT2_INPUTS_DOCSTRING = r\"\"\"\n",
85
- " Args:\n",
86
- " input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size,input_ids_length)`):\n",
87
- " :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.\n",
88
- " Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See\n",
89
- " :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for\n",
90
- " details.\n",
91
- " `What are input IDs? <../glossary.html#input-ids>`__\n",
92
- " attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
93
- " Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:\n",
94
- " - 1 for tokens that are **not masked**,\n",
95
- " - 0 for tokens that are **masked**.\n",
96
- " `What are attention masks? <../glossary.html#attention-mask>`__\n",
97
- " position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
98
- " Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,\n",
99
- " config.max_position_embeddings - 1]``.\n",
100
- " past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):\n",
101
- " Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast\n",
102
- " auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.\n",
103
- " output_attentions (:obj:`bool`, `optional`):\n",
104
- " Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned\n",
105
- " tensors for more detail.\n",
106
- " output_hidden_states (:obj:`bool`, `optional`):\n",
107
- " Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for\n",
108
- " more detail.\n",
109
- " return_dict (:obj:`bool`, `optional`):\n",
110
- " Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.\n",
111
- "\"\"\""
112
- ],
113
- "execution_count": 3,
114
- "outputs": []
115
- },
116
- {
117
- "cell_type": "code",
118
- "metadata": {
119
- "id": "NX-Z5iCMbKL5"
120
- },
121
- "source": [
122
- "class FlaxGGGPreTrainedModel(FlaxPreTrainedModel):\n",
123
- " \"\"\"\n",
124
- " An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n",
125
- " models.\n",
126
- " \"\"\"\n",
127
- "\n",
128
- " config_class = GPT2Config\n",
129
- " base_model_prefix = \"transformer\"\n",
130
- " module_class: nn.Module = None\n",
131
- "\n",
132
- " def __init__(\n",
133
- " self,\n",
134
- " config: GPT2Config,\n",
135
- " input_shape: Tuple = (1,1),\n",
136
- " seed: int = 0,\n",
137
- " dtype: jnp.dtype = jnp.float32,\n",
138
- " **kwargs,\n",
139
- " ):\n",
140
- " \n",
141
- " module = self.module_class(config=config, dtype=dtype, **kwargs)\n",
142
- " super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)\n",
143
- "\n",
144
- " def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:\n",
145
- " # init input tensors\n",
146
- " input_ids = jnp.zeros(input_shape, dtype=\"i4\")\n",
147
- " attention_mask = jnp.ones_like(input_ids)\n",
148
- " \n",
149
- " position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)\n",
150
- "\n",
151
- " params_rng, dropout_rng = jax.random.split(rng)\n",
152
- " rngs = {\"params\": params_rng, \"dropout\": dropout_rng}\n",
153
- "\n",
154
- " return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)[\"params\"]\n",
155
- "\n",
156
- " def init_cache(self, batch_size, max_length):\n",
157
- " r\"\"\"\n",
158
- " Args:\n",
159
- " batch_size (:obj:`int`):\n",
160
- " batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.\n",
161
- " max_length (:obj:`int`):\n",
162
- " maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized\n",
163
- " cache.\n",
164
- " \"\"\"\n",
165
- " # init input variables to retrieve cache\n",
166
- " input_ids = jnp.ones((batch_size, max_length))\n",
167
- " attention_mask = jnp.ones_like(input_ids)\n",
168
- " position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)\n",
169
- "\n",
170
- " init_variables = self.module.init(\n",
171
- " jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True\n",
172
- " )\n",
173
- " return init_variables[\"cache\"]\n",
174
- "\n",
175
- " @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)\n",
176
- " def __call__(\n",
177
- " self,\n",
178
- " input_ids,\n",
179
- " attention_mask=None,\n",
180
- " position_ids=None,\n",
181
- " params: dict = None,\n",
182
- " past_key_values: dict = None,\n",
183
- " dropout_rng: jax.random.PRNGKey = None,\n",
184
- " train: bool = False,\n",
185
- " output_attentions: Optional[bool] = None,\n",
186
- " output_hidden_states: Optional[bool] = None,\n",
187
- " return_dict: Optional[bool] = None,\n",
188
- " ):\n",
189
- " output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n",
190
- " output_hidden_states = (\n",
191
- " output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
192
- " )\n",
193
- " return_dict = return_dict if return_dict is not None else self.config.return_dict\n",
194
- " print(input_ids.shape)\n",
195
- "\n",
196
- " # batch_size, num_choices,sequence_length = input_ids.shape\n",
197
- "\n",
198
- " if position_ids is None:\n",
199
- " if past_key_values is not None:\n",
200
- " raise ValueError(\"Make sure to provide `position_ids` when passing `past_key_values`.\")\n",
201
- " \n",
202
- " position_ids=jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)\n",
203
- "\n",
204
- " # position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))\n",
205
- "\n",
206
- " if attention_mask is None:\n",
207
- " attention_mask = jnp.ones((input_ids))\n",
208
- " print('attn not')\n",
209
- "\n",
210
- " # Handle any PRNG if needed\n",
211
- " rngs = {}\n",
212
- " if dropout_rng is not None:\n",
213
- " rngs[\"dropout\"] = dropout_rng\n",
214
- "\n",
215
- " inputs = {\"params\": params or self.params}\n",
216
- "\n",
217
- " # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module\n",
218
- " if past_key_values:\n",
219
- " inputs[\"cache\"] = past_key_values\n",
220
- " mutable = [\"cache\"]\n",
221
- " else:\n",
222
- " mutable = False\n",
223
- "\n",
224
- " outputs = self.module.apply(\n",
225
- " inputs,\n",
226
- " jnp.array(input_ids, dtype=\"i4\"),\n",
227
- " jnp.array(attention_mask, dtype=\"i4\"),\n",
228
- " jnp.array(position_ids, dtype=\"i4\"),\n",
229
- " not train,\n",
230
- " False,\n",
231
- " output_attentions,\n",
232
- " output_hidden_states,\n",
233
- " return_dict,\n",
234
- " rngs=rngs,\n",
235
- " mutable=mutable,\n",
236
- " )\n",
237
- " print('cache')\n",
238
- "\n",
239
- " # add updated cache to model output\n",
240
- " if past_key_values is not None and return_dict:\n",
241
- " outputs, past_key_values = outputs\n",
242
- " outputs[\"past_key_values\"] = unfreeze(past_key_values[\"cache\"])\n",
243
- " return outputs\n",
244
- " elif past_key_values is not None and not return_dict:\n",
245
- " outputs, past_key_values = outputs\n",
246
- " outputs = outputs[:1] + (unfreeze(past_key_values[\"cache\"]),) + outputs[1:]\n",
247
- "\n",
248
- " return outputs"
249
- ],
250
- "execution_count": 6,
251
- "outputs": []
252
- },
253
- {
254
- "cell_type": "code",
255
- "metadata": {
256
- "id": "4vRAWll2bwQQ"
257
- },
258
- "source": [
259
- "class FlaxGGGModule(nn.Module):\n",
260
- " config: GPT2Config\n",
261
- " dtype: jnp.dtype = jnp.float32\n",
262
- "\n",
263
- " def setup(self):\n",
264
- " self.embed_dim = self.config.hidden_size\n",
265
- "\n",
266
- " self.wte = nn.Embed(\n",
267
- " self.config.vocab_size,\n",
268
- " self.embed_dim,\n",
269
- " embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n",
270
- " dtype=self.dtype,\n",
271
- " )\n",
272
- " self.wpe = nn.Embed(\n",
273
- " self.config.max_position_embeddings,\n",
274
- " self.embed_dim,\n",
275
- " embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n",
276
- " dtype=self.dtype,\n",
277
- " )\n",
278
- " self.dropout = nn.Dropout(rate=self.config.embd_pdrop)\n",
279
- " self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)\n",
280
- " self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)\n",
281
- "\n",
282
- " def __call__(\n",
283
- " self,\n",
284
- " input_ids,\n",
285
- " attention_mask,\n",
286
- " position_ids,\n",
287
- " deterministic=True,\n",
288
- " init_cache: bool = False,\n",
289
- " output_attentions: bool = False,\n",
290
- " output_hidden_states: bool = False,\n",
291
- " return_dict: bool = True,\n",
292
- " ):\n",
293
- " input_embeds = self.wte(input_ids.astype(\"i4\"))\n",
294
- " position_embeds = self.wpe(position_ids.astype(\"i4\"))\n",
295
- " \n",
296
- "\n",
297
- " hidden_states = input_embeds + position_embeds\n",
298
- " hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n",
299
- " outputs = self.h(\n",
300
- " hidden_states,\n",
301
- " attention_mask,\n",
302
- " deterministic=deterministic,\n",
303
- " init_cache=init_cache,\n",
304
- " output_attentions=output_attentions,\n",
305
- " output_hidden_states=output_hidden_states,\n",
306
- " return_dict=return_dict,\n",
307
- " )\n",
308
- "\n",
309
- " hidden_states = outputs[0]\n",
310
- " hidden_states = self.ln_f(hidden_states)\n",
311
- " print('ggg')\n",
312
- " if not return_dict:\n",
313
- " return (hidden_states,) + outputs[1:]\n",
314
- "\n",
315
- " return FlaxBaseModelOutput(\n",
316
- " last_hidden_state=hidden_states,\n",
317
- " hidden_states=outputs.hidden_states,\n",
318
- " attentions=outputs.attentions,)\n",
319
- "class FlaxNewModel(FlaxGGGPreTrainedModel):\n",
320
- " module_class = FlaxGGGModule"
321
- ],
322
- "execution_count": 7,
323
- "outputs": []
324
- },
325
- {
326
- "cell_type": "code",
327
- "metadata": {
328
- "id": "_ljSn6GdedtI"
329
- },
330
- "source": [
331
- "class FlaxGPT2ForMultipleChoiceModule(nn.Module):\n",
332
- " config:GPT2Config\n",
333
- " dtype: jnp.dtype = jnp.float32\n",
334
- " def setup(self):\n",
335
- " self.gpt2 = FlaxNewModel(config=self.config, dtype=self.dtype)\n",
336
- " self.dropout = nn.Dropout(rate=0.2)\n",
337
- " self.classifier = nn.Dense(4, dtype=self.dtype)\n",
338
- "\n",
339
- " def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):\n",
340
- " batch_size = input_ids.shape[0]\n",
341
- " rng=jax.random.PRNGKey(0)\n",
342
- " _, dropout_rng = jax.random.split(rng)\n",
343
- " print('abc')\n",
344
- "\n",
345
- " outputs=self.gpt2(input_ids, attention_mask,position_ids,return_dict=return_dict)\n",
346
- " \n",
347
- "\n",
348
- " hidden_states = outputs[0]\n",
349
- "\n",
350
- " \n",
351
- " hidden_states= jnp.mean(hidden_states, axis=1)\n",
352
- "\n",
353
- " print(hidden_states.shape)\n",
354
- " \n",
355
- " \n",
356
- " hidden_states=hidden_states.reshape(batch_size,-1) #(32,8,768)->(32,8*768)\n",
- "\n",
- " dropout_output = self.dropout(hidden_states,deterministic=deterministic,rng=dropout_rng)\n",
- "\n",
- " print(dropout_output.shape)\n",
- " \n",
- "\n",
- " logits = self.classifier(dropout_output)\n",
- " print('bnv')\n",
- " reshaped_logits = logits.reshape(-1, 4) \n",
- " #(32,4)\n",
- " if not return_dict:\n",
- " return (reshaped_logits,) + outputs[2:]\n",
- " return reshaped_logits"
- ],
- "execution_count": 8,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "M4UPf3Waexq0"
- },
- "source": [
- "class FlaxGPT2ForMultipleChoice(FlaxNewModel):\n",
- " module_class = FlaxGPT2ForMultipleChoiceModule"
- ],
- "execution_count": 9,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "roQ3vls4e4TH"
- },
- "source": [
- "model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2')"
- ],
- "execution_count": null,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "id": "E9qOSaaie417"
- },
- "source": [
- "input_ids=jnp.ones((1,2,11))\n",
- "attention_mask=jnp.ones((1,2,11))"
- ],
- "execution_count": 12,
- "outputs": []
- },
- {
- "cell_type": "code",
- "metadata": {
- "colab": {
- "base_uri": "https://localhost:8080/",
- "height": 409
- },
- "id": "am7hYv8auWVy",
- "outputId": "0c8192ca-a0ab-432e-d483-46f8a2cc2576"
- },
- "source": [
- "out1 = model(input_ids, attention_mask)"
- ],
- "execution_count": 13,
- "outputs": [
- {
- "output_type": "stream",
- "text": [
- "(1, 2, 11)\n",
- "attn not\n",
- "ggg\n",
- "abc\n",
- "(1, 2, 11)\n",
- "attn not\n"
- ],
- "name": "stdout"
- },
- {
- "output_type": "error",
- "ename": "ValueError",
- "evalue": "ignored",
- "traceback": [
- "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
- "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
- "\u001b[0;32m<ipython-input-13-6be36035677e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mout1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
- "\u001b[0;32m<ipython-input-6-de553f26d169>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m )\n\u001b[1;32m 116\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cache'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 966\u001b[0;31m )(variables, *args, **kwargs, rngs=rngs)\n\u001b[0m\u001b[1;32m 967\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 968\u001b[0m def init_with_output(self,\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/scope.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(variables, rngs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 685\u001b[0m **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:\n\u001b[1;32m 686\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvariables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtemporary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 687\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 688\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmutable\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 689\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmutable_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mscope_fn\u001b[0;34m(scope, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1214\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1216\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1217\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m<ipython-input-8-2c21e4c966c8>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, return_dict, deterministic, *args)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'abc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0moutputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpt2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m<ipython-input-6-de553f26d169>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m )\n\u001b[1;32m 116\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cache'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 966\u001b[0;31m )(variables, *args, **kwargs, rngs=rngs)\n\u001b[0m\u001b[1;32m 967\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 968\u001b[0m def init_with_output(self,\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/scope.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(variables, rngs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 685\u001b[0m **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:\n\u001b[1;32m 686\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvariables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtemporary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 687\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 688\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmutable\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 689\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmutable_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mscope_fn\u001b[0;34m(scope, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1214\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1216\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1217\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m<ipython-input-7-b2eaa3f7b251>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m )\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, hidden_states, attention_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0mdeterministic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdeterministic\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0minit_cache\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_cache\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 455\u001b[0m )\n\u001b[1;32m 456\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, hidden_states, attention_mask, deterministic, init_cache, output_attentions)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mdeterministic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdeterministic\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0minit_cache\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_cache\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 288\u001b[0m )\n\u001b[1;32m 289\u001b[0m \u001b[0;31m# residual connection\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, hidden_states, attention_mask, deterministic, init_cache, output_attentions)\u001b[0m\n\u001b[1;32m 177\u001b[0m ):\n\u001b[1;32m 178\u001b[0m \u001b[0mqkv_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_attn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mquery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mquery\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_heads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36msplit\u001b[0;34m(ary, indices_or_sections, axis)\u001b[0m\n\u001b[1;32m 1806\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_wraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1807\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices_or_sections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1808\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"split\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices_or_sections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1809\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1810\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_split_on_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_split\u001b[0;34m(op, ary, indices_or_sections, axis)\u001b[0m\n\u001b[1;32m 1798\u001b[0m + ((r + 1) * (part_size + 1) - 1)])\n\u001b[1;32m 1799\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1800\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"array split does not result in an equal division\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1801\u001b[0m \u001b[0mstarts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mends\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mndim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1802\u001b[0m \u001b[0m_subval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0msubvals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
- "\u001b[0;31mValueError\u001b[0m: array split does not result in an equal division"
- ]
- }
- ]
- }
- ]
- }