Vivek committed on
Commit
899554e
1 Parent(s): fde4e6c

add colab notebook

Files changed (2)
  1. GPT2(error).ipynb +1074 -0
  2. Untitled330.ipynb +470 -0
GPT2(error).ipynb ADDED
@@ -0,0 +1,1074 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "name": "GPT2(error).ipynb",
7
+ "provenance": [],
8
+ "collapsed_sections": []
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ },
17
+ "widgets": {
18
+ "application/vnd.jupyter.widget-state+json": {
19
+ "1b266a2c1cf646a392a46e39586282b3": {
20
+ "model_module": "@jupyter-widgets/controls",
21
+ "model_name": "HBoxModel",
22
+ "state": {
23
+ "_view_name": "HBoxView",
24
+ "_dom_classes": [],
25
+ "_model_name": "HBoxModel",
26
+ "_view_module": "@jupyter-widgets/controls",
27
+ "_model_module_version": "1.5.0",
28
+ "_view_count": null,
29
+ "_view_module_version": "1.5.0",
30
+ "box_style": "",
31
+ "layout": "IPY_MODEL_8ecfcf14981c4d82b5d9d3839a496f0b",
32
+ "_model_module": "@jupyter-widgets/controls",
33
+ "children": [
34
+ "IPY_MODEL_16b07572ac0d46798b2c2a292c3f9143",
35
+ "IPY_MODEL_cf412ff73fc647908154abc9b2847f38"
36
+ ]
37
+ }
38
+ },
39
+ "8ecfcf14981c4d82b5d9d3839a496f0b": {
40
+ "model_module": "@jupyter-widgets/base",
41
+ "model_name": "LayoutModel",
42
+ "state": {
43
+ "_view_name": "LayoutView",
44
+ "grid_template_rows": null,
45
+ "right": null,
46
+ "justify_content": null,
47
+ "_view_module": "@jupyter-widgets/base",
48
+ "overflow": null,
49
+ "_model_module_version": "1.2.0",
50
+ "_view_count": null,
51
+ "flex_flow": null,
52
+ "width": null,
53
+ "min_width": null,
54
+ "border": null,
55
+ "align_items": null,
56
+ "bottom": null,
57
+ "_model_module": "@jupyter-widgets/base",
58
+ "top": null,
59
+ "grid_column": null,
60
+ "overflow_y": null,
61
+ "overflow_x": null,
62
+ "grid_auto_flow": null,
63
+ "grid_area": null,
64
+ "grid_template_columns": null,
65
+ "flex": null,
66
+ "_model_name": "LayoutModel",
67
+ "justify_items": null,
68
+ "grid_row": null,
69
+ "max_height": null,
70
+ "align_content": null,
71
+ "visibility": null,
72
+ "align_self": null,
73
+ "height": null,
74
+ "min_height": null,
75
+ "padding": null,
76
+ "grid_auto_rows": null,
77
+ "grid_gap": null,
78
+ "max_width": null,
79
+ "order": null,
80
+ "_view_module_version": "1.2.0",
81
+ "grid_template_areas": null,
82
+ "object_position": null,
83
+ "object_fit": null,
84
+ "grid_auto_columns": null,
85
+ "margin": null,
86
+ "display": null,
87
+ "left": null
88
+ }
89
+ },
90
+ "16b07572ac0d46798b2c2a292c3f9143": {
91
+ "model_module": "@jupyter-widgets/controls",
92
+ "model_name": "FloatProgressModel",
93
+ "state": {
94
+ "_view_name": "ProgressView",
95
+ "style": "IPY_MODEL_6ebc21286ae843e5b9ba4df8f4cebfe0",
96
+ "_dom_classes": [],
97
+ "description": "Downloading: 100%",
98
+ "_model_name": "FloatProgressModel",
99
+ "bar_style": "success",
100
+ "max": 1042301,
101
+ "_view_module": "@jupyter-widgets/controls",
102
+ "_model_module_version": "1.5.0",
103
+ "value": 1042301,
104
+ "_view_count": null,
105
+ "_view_module_version": "1.5.0",
106
+ "orientation": "horizontal",
107
+ "min": 0,
108
+ "description_tooltip": null,
109
+ "_model_module": "@jupyter-widgets/controls",
110
+ "layout": "IPY_MODEL_48246be80e82429da2d48f9d4a1aaf0a"
111
+ }
112
+ },
113
+ "cf412ff73fc647908154abc9b2847f38": {
114
+ "model_module": "@jupyter-widgets/controls",
115
+ "model_name": "HTMLModel",
116
+ "state": {
117
+ "_view_name": "HTMLView",
118
+ "style": "IPY_MODEL_5754900e885d4f509ede058b186fcab6",
119
+ "_dom_classes": [],
120
+ "description": "",
121
+ "_model_name": "HTMLModel",
122
+ "placeholder": "​",
123
+ "_view_module": "@jupyter-widgets/controls",
124
+ "_model_module_version": "1.5.0",
125
+ "value": " 1.04M/1.04M [00:06<00:00, 154kB/s]",
126
+ "_view_count": null,
127
+ "_view_module_version": "1.5.0",
128
+ "description_tooltip": null,
129
+ "_model_module": "@jupyter-widgets/controls",
130
+ "layout": "IPY_MODEL_d0434381119c46489e17fcbccd9755ea"
131
+ }
132
+ },
133
+ "6ebc21286ae843e5b9ba4df8f4cebfe0": {
134
+ "model_module": "@jupyter-widgets/controls",
135
+ "model_name": "ProgressStyleModel",
136
+ "state": {
137
+ "_view_name": "StyleView",
138
+ "_model_name": "ProgressStyleModel",
139
+ "description_width": "initial",
140
+ "_view_module": "@jupyter-widgets/base",
141
+ "_model_module_version": "1.5.0",
142
+ "_view_count": null,
143
+ "_view_module_version": "1.2.0",
144
+ "bar_color": null,
145
+ "_model_module": "@jupyter-widgets/controls"
146
+ }
147
+ },
148
+ "48246be80e82429da2d48f9d4a1aaf0a": {
149
+ "model_module": "@jupyter-widgets/base",
150
+ "model_name": "LayoutModel",
151
+ "state": {
152
+ "_view_name": "LayoutView",
153
+ "grid_template_rows": null,
154
+ "right": null,
155
+ "justify_content": null,
156
+ "_view_module": "@jupyter-widgets/base",
157
+ "overflow": null,
158
+ "_model_module_version": "1.2.0",
159
+ "_view_count": null,
160
+ "flex_flow": null,
161
+ "width": null,
162
+ "min_width": null,
163
+ "border": null,
164
+ "align_items": null,
165
+ "bottom": null,
166
+ "_model_module": "@jupyter-widgets/base",
167
+ "top": null,
168
+ "grid_column": null,
169
+ "overflow_y": null,
170
+ "overflow_x": null,
171
+ "grid_auto_flow": null,
172
+ "grid_area": null,
173
+ "grid_template_columns": null,
174
+ "flex": null,
175
+ "_model_name": "LayoutModel",
176
+ "justify_items": null,
177
+ "grid_row": null,
178
+ "max_height": null,
179
+ "align_content": null,
180
+ "visibility": null,
181
+ "align_self": null,
182
+ "height": null,
183
+ "min_height": null,
184
+ "padding": null,
185
+ "grid_auto_rows": null,
186
+ "grid_gap": null,
187
+ "max_width": null,
188
+ "order": null,
189
+ "_view_module_version": "1.2.0",
190
+ "grid_template_areas": null,
191
+ "object_position": null,
192
+ "object_fit": null,
193
+ "grid_auto_columns": null,
194
+ "margin": null,
195
+ "display": null,
196
+ "left": null
197
+ }
198
+ },
199
+ "5754900e885d4f509ede058b186fcab6": {
200
+ "model_module": "@jupyter-widgets/controls",
201
+ "model_name": "DescriptionStyleModel",
202
+ "state": {
203
+ "_view_name": "StyleView",
204
+ "_model_name": "DescriptionStyleModel",
205
+ "description_width": "",
206
+ "_view_module": "@jupyter-widgets/base",
207
+ "_model_module_version": "1.5.0",
208
+ "_view_count": null,
209
+ "_view_module_version": "1.2.0",
210
+ "_model_module": "@jupyter-widgets/controls"
211
+ }
212
+ },
213
+ "d0434381119c46489e17fcbccd9755ea": {
214
+ "model_module": "@jupyter-widgets/base",
215
+ "model_name": "LayoutModel",
216
+ "state": {
217
+ "_view_name": "LayoutView",
218
+ "grid_template_rows": null,
219
+ "right": null,
220
+ "justify_content": null,
221
+ "_view_module": "@jupyter-widgets/base",
222
+ "overflow": null,
223
+ "_model_module_version": "1.2.0",
224
+ "_view_count": null,
225
+ "flex_flow": null,
226
+ "width": null,
227
+ "min_width": null,
228
+ "border": null,
229
+ "align_items": null,
230
+ "bottom": null,
231
+ "_model_module": "@jupyter-widgets/base",
232
+ "top": null,
233
+ "grid_column": null,
234
+ "overflow_y": null,
235
+ "overflow_x": null,
236
+ "grid_auto_flow": null,
237
+ "grid_area": null,
238
+ "grid_template_columns": null,
239
+ "flex": null,
240
+ "_model_name": "LayoutModel",
241
+ "justify_items": null,
242
+ "grid_row": null,
243
+ "max_height": null,
244
+ "align_content": null,
245
+ "visibility": null,
246
+ "align_self": null,
247
+ "height": null,
248
+ "min_height": null,
249
+ "padding": null,
250
+ "grid_auto_rows": null,
251
+ "grid_gap": null,
252
+ "max_width": null,
253
+ "order": null,
254
+ "_view_module_version": "1.2.0",
255
+ "grid_template_areas": null,
256
+ "object_position": null,
257
+ "object_fit": null,
258
+ "grid_auto_columns": null,
259
+ "margin": null,
260
+ "display": null,
261
+ "left": null
262
+ }
263
+ },
264
+ "73c4b8bc05f64477aa03d767f4483795": {
265
+ "model_module": "@jupyter-widgets/controls",
266
+ "model_name": "HBoxModel",
267
+ "state": {
268
+ "_view_name": "HBoxView",
269
+ "_dom_classes": [],
270
+ "_model_name": "HBoxModel",
271
+ "_view_module": "@jupyter-widgets/controls",
272
+ "_model_module_version": "1.5.0",
273
+ "_view_count": null,
274
+ "_view_module_version": "1.5.0",
275
+ "box_style": "",
276
+ "layout": "IPY_MODEL_6123827ad5964b4b8a17aaca618b4768",
277
+ "_model_module": "@jupyter-widgets/controls",
278
+ "children": [
279
+ "IPY_MODEL_5327d425e74d4a599214282b9b70d58b",
280
+ "IPY_MODEL_974490d04f18407f9f5a5785b2802c0a"
281
+ ]
282
+ }
283
+ },
284
+ "6123827ad5964b4b8a17aaca618b4768": {
285
+ "model_module": "@jupyter-widgets/base",
286
+ "model_name": "LayoutModel",
287
+ "state": {
288
+ "_view_name": "LayoutView",
289
+ "grid_template_rows": null,
290
+ "right": null,
291
+ "justify_content": null,
292
+ "_view_module": "@jupyter-widgets/base",
293
+ "overflow": null,
294
+ "_model_module_version": "1.2.0",
295
+ "_view_count": null,
296
+ "flex_flow": null,
297
+ "width": null,
298
+ "min_width": null,
299
+ "border": null,
300
+ "align_items": null,
301
+ "bottom": null,
302
+ "_model_module": "@jupyter-widgets/base",
303
+ "top": null,
304
+ "grid_column": null,
305
+ "overflow_y": null,
306
+ "overflow_x": null,
307
+ "grid_auto_flow": null,
308
+ "grid_area": null,
309
+ "grid_template_columns": null,
310
+ "flex": null,
311
+ "_model_name": "LayoutModel",
312
+ "justify_items": null,
313
+ "grid_row": null,
314
+ "max_height": null,
315
+ "align_content": null,
316
+ "visibility": null,
317
+ "align_self": null,
318
+ "height": null,
319
+ "min_height": null,
320
+ "padding": null,
321
+ "grid_auto_rows": null,
322
+ "grid_gap": null,
323
+ "max_width": null,
324
+ "order": null,
325
+ "_view_module_version": "1.2.0",
326
+ "grid_template_areas": null,
327
+ "object_position": null,
328
+ "object_fit": null,
329
+ "grid_auto_columns": null,
330
+ "margin": null,
331
+ "display": null,
332
+ "left": null
333
+ }
334
+ },
335
+ "5327d425e74d4a599214282b9b70d58b": {
336
+ "model_module": "@jupyter-widgets/controls",
337
+ "model_name": "FloatProgressModel",
338
+ "state": {
339
+ "_view_name": "ProgressView",
340
+ "style": "IPY_MODEL_c3cc1723c39a4d74b2ab83bd23b5fcce",
341
+ "_dom_classes": [],
342
+ "description": "Downloading: 100%",
343
+ "_model_name": "FloatProgressModel",
344
+ "bar_style": "success",
345
+ "max": 456318,
346
+ "_view_module": "@jupyter-widgets/controls",
347
+ "_model_module_version": "1.5.0",
348
+ "value": 456318,
349
+ "_view_count": null,
350
+ "_view_module_version": "1.5.0",
351
+ "orientation": "horizontal",
352
+ "min": 0,
353
+ "description_tooltip": null,
354
+ "_model_module": "@jupyter-widgets/controls",
355
+ "layout": "IPY_MODEL_391d59bf8d2845f88a83dc25c7cf89f3"
356
+ }
357
+ },
358
+ "974490d04f18407f9f5a5785b2802c0a": {
359
+ "model_module": "@jupyter-widgets/controls",
360
+ "model_name": "HTMLModel",
361
+ "state": {
362
+ "_view_name": "HTMLView",
363
+ "style": "IPY_MODEL_d60fa9fe71444784b78bdfba6ed6a9e1",
364
+ "_dom_classes": [],
365
+ "description": "",
366
+ "_model_name": "HTMLModel",
367
+ "placeholder": "​",
368
+ "_view_module": "@jupyter-widgets/controls",
369
+ "_model_module_version": "1.5.0",
370
+ "value": " 456k/456k [00:04<00:00, 96.1kB/s]",
371
+ "_view_count": null,
372
+ "_view_module_version": "1.5.0",
373
+ "description_tooltip": null,
374
+ "_model_module": "@jupyter-widgets/controls",
375
+ "layout": "IPY_MODEL_41a3b55e5e264b85ada9558e5777790f"
376
+ }
377
+ },
378
+ "c3cc1723c39a4d74b2ab83bd23b5fcce": {
379
+ "model_module": "@jupyter-widgets/controls",
380
+ "model_name": "ProgressStyleModel",
381
+ "state": {
382
+ "_view_name": "StyleView",
383
+ "_model_name": "ProgressStyleModel",
384
+ "description_width": "initial",
385
+ "_view_module": "@jupyter-widgets/base",
386
+ "_model_module_version": "1.5.0",
387
+ "_view_count": null,
388
+ "_view_module_version": "1.2.0",
389
+ "bar_color": null,
390
+ "_model_module": "@jupyter-widgets/controls"
391
+ }
392
+ },
393
+ "391d59bf8d2845f88a83dc25c7cf89f3": {
394
+ "model_module": "@jupyter-widgets/base",
395
+ "model_name": "LayoutModel",
396
+ "state": {
397
+ "_view_name": "LayoutView",
398
+ "grid_template_rows": null,
399
+ "right": null,
400
+ "justify_content": null,
401
+ "_view_module": "@jupyter-widgets/base",
402
+ "overflow": null,
403
+ "_model_module_version": "1.2.0",
404
+ "_view_count": null,
405
+ "flex_flow": null,
406
+ "width": null,
407
+ "min_width": null,
408
+ "border": null,
409
+ "align_items": null,
410
+ "bottom": null,
411
+ "_model_module": "@jupyter-widgets/base",
412
+ "top": null,
413
+ "grid_column": null,
414
+ "overflow_y": null,
415
+ "overflow_x": null,
416
+ "grid_auto_flow": null,
417
+ "grid_area": null,
418
+ "grid_template_columns": null,
419
+ "flex": null,
420
+ "_model_name": "LayoutModel",
421
+ "justify_items": null,
422
+ "grid_row": null,
423
+ "max_height": null,
424
+ "align_content": null,
425
+ "visibility": null,
426
+ "align_self": null,
427
+ "height": null,
428
+ "min_height": null,
429
+ "padding": null,
430
+ "grid_auto_rows": null,
431
+ "grid_gap": null,
432
+ "max_width": null,
433
+ "order": null,
434
+ "_view_module_version": "1.2.0",
435
+ "grid_template_areas": null,
436
+ "object_position": null,
437
+ "object_fit": null,
438
+ "grid_auto_columns": null,
439
+ "margin": null,
440
+ "display": null,
441
+ "left": null
442
+ }
443
+ },
444
+ "d60fa9fe71444784b78bdfba6ed6a9e1": {
445
+ "model_module": "@jupyter-widgets/controls",
446
+ "model_name": "DescriptionStyleModel",
447
+ "state": {
448
+ "_view_name": "StyleView",
449
+ "_model_name": "DescriptionStyleModel",
450
+ "description_width": "",
451
+ "_view_module": "@jupyter-widgets/base",
452
+ "_model_module_version": "1.5.0",
453
+ "_view_count": null,
454
+ "_view_module_version": "1.2.0",
455
+ "_model_module": "@jupyter-widgets/controls"
456
+ }
457
+ },
458
+ "41a3b55e5e264b85ada9558e5777790f": {
459
+ "model_module": "@jupyter-widgets/base",
460
+ "model_name": "LayoutModel",
461
+ "state": {
462
+ "_view_name": "LayoutView",
463
+ "grid_template_rows": null,
464
+ "right": null,
465
+ "justify_content": null,
466
+ "_view_module": "@jupyter-widgets/base",
467
+ "overflow": null,
468
+ "_model_module_version": "1.2.0",
469
+ "_view_count": null,
470
+ "flex_flow": null,
471
+ "width": null,
472
+ "min_width": null,
473
+ "border": null,
474
+ "align_items": null,
475
+ "bottom": null,
476
+ "_model_module": "@jupyter-widgets/base",
477
+ "top": null,
478
+ "grid_column": null,
479
+ "overflow_y": null,
480
+ "overflow_x": null,
481
+ "grid_auto_flow": null,
482
+ "grid_area": null,
483
+ "grid_template_columns": null,
484
+ "flex": null,
485
+ "_model_name": "LayoutModel",
486
+ "justify_items": null,
487
+ "grid_row": null,
488
+ "max_height": null,
489
+ "align_content": null,
490
+ "visibility": null,
491
+ "align_self": null,
492
+ "height": null,
493
+ "min_height": null,
494
+ "padding": null,
495
+ "grid_auto_rows": null,
496
+ "grid_gap": null,
497
+ "max_width": null,
498
+ "order": null,
499
+ "_view_module_version": "1.2.0",
500
+ "grid_template_areas": null,
501
+ "object_position": null,
502
+ "object_fit": null,
503
+ "grid_auto_columns": null,
504
+ "margin": null,
505
+ "display": null,
506
+ "left": null
507
+ }
508
+ },
509
+ "aa4d6e2e9ac44e9bb40b7daccc91ee83": {
510
+ "model_module": "@jupyter-widgets/controls",
511
+ "model_name": "HBoxModel",
512
+ "state": {
513
+ "_view_name": "HBoxView",
514
+ "_dom_classes": [],
515
+ "_model_name": "HBoxModel",
516
+ "_view_module": "@jupyter-widgets/controls",
517
+ "_model_module_version": "1.5.0",
518
+ "_view_count": null,
519
+ "_view_module_version": "1.5.0",
520
+ "box_style": "",
521
+ "layout": "IPY_MODEL_c3b054972a6145d1ad03ca938a7ade9c",
522
+ "_model_module": "@jupyter-widgets/controls",
523
+ "children": [
524
+ "IPY_MODEL_644f4a69db534dd4a11172e5d010e8fe",
525
+ "IPY_MODEL_0e19c2d5b060490399efbfcda773e9ba"
526
+ ]
527
+ }
528
+ },
529
+ "c3b054972a6145d1ad03ca938a7ade9c": {
530
+ "model_module": "@jupyter-widgets/base",
531
+ "model_name": "LayoutModel",
532
+ "state": {
533
+ "_view_name": "LayoutView",
534
+ "grid_template_rows": null,
535
+ "right": null,
536
+ "justify_content": null,
537
+ "_view_module": "@jupyter-widgets/base",
538
+ "overflow": null,
539
+ "_model_module_version": "1.2.0",
540
+ "_view_count": null,
541
+ "flex_flow": null,
542
+ "width": null,
543
+ "min_width": null,
544
+ "border": null,
545
+ "align_items": null,
546
+ "bottom": null,
547
+ "_model_module": "@jupyter-widgets/base",
548
+ "top": null,
549
+ "grid_column": null,
550
+ "overflow_y": null,
551
+ "overflow_x": null,
552
+ "grid_auto_flow": null,
553
+ "grid_area": null,
554
+ "grid_template_columns": null,
555
+ "flex": null,
556
+ "_model_name": "LayoutModel",
557
+ "justify_items": null,
558
+ "grid_row": null,
559
+ "max_height": null,
560
+ "align_content": null,
561
+ "visibility": null,
562
+ "align_self": null,
563
+ "height": null,
564
+ "min_height": null,
565
+ "padding": null,
566
+ "grid_auto_rows": null,
567
+ "grid_gap": null,
568
+ "max_width": null,
569
+ "order": null,
570
+ "_view_module_version": "1.2.0",
571
+ "grid_template_areas": null,
572
+ "object_position": null,
573
+ "object_fit": null,
574
+ "grid_auto_columns": null,
575
+ "margin": null,
576
+ "display": null,
577
+ "left": null
578
+ }
579
+ },
580
+ "644f4a69db534dd4a11172e5d010e8fe": {
581
+ "model_module": "@jupyter-widgets/controls",
582
+ "model_name": "FloatProgressModel",
583
+ "state": {
584
+ "_view_name": "ProgressView",
585
+ "style": "IPY_MODEL_109479db406d4085acff84904cdac4ef",
586
+ "_dom_classes": [],
587
+ "description": "Downloading: 100%",
588
+ "_model_name": "FloatProgressModel",
589
+ "bar_style": "success",
590
+ "max": 1355256,
591
+ "_view_module": "@jupyter-widgets/controls",
592
+ "_model_module_version": "1.5.0",
593
+ "value": 1355256,
594
+ "_view_count": null,
595
+ "_view_module_version": "1.5.0",
596
+ "orientation": "horizontal",
597
+ "min": 0,
598
+ "description_tooltip": null,
599
+ "_model_module": "@jupyter-widgets/controls",
600
+ "layout": "IPY_MODEL_1fb7fd7b44bf4b3a9e949f025487ff47"
601
+ }
602
+ },
603
+ "0e19c2d5b060490399efbfcda773e9ba": {
604
+ "model_module": "@jupyter-widgets/controls",
605
+ "model_name": "HTMLModel",
606
+ "state": {
607
+ "_view_name": "HTMLView",
608
+ "style": "IPY_MODEL_01c30370e3ff4b5caf9a3369841ad597",
609
+ "_dom_classes": [],
610
+ "description": "",
611
+ "_model_name": "HTMLModel",
612
+ "placeholder": "​",
613
+ "_view_module": "@jupyter-widgets/controls",
614
+ "_model_module_version": "1.5.0",
615
+ "value": " 1.36M/1.36M [00:00<00:00, 1.73MB/s]",
616
+ "_view_count": null,
617
+ "_view_module_version": "1.5.0",
618
+ "description_tooltip": null,
619
+ "_model_module": "@jupyter-widgets/controls",
620
+ "layout": "IPY_MODEL_d19eedc2acce48d4be7189b422b5fcb9"
621
+ }
622
+ },
623
+ "109479db406d4085acff84904cdac4ef": {
624
+ "model_module": "@jupyter-widgets/controls",
625
+ "model_name": "ProgressStyleModel",
626
+ "state": {
627
+ "_view_name": "StyleView",
628
+ "_model_name": "ProgressStyleModel",
629
+ "description_width": "initial",
630
+ "_view_module": "@jupyter-widgets/base",
631
+ "_model_module_version": "1.5.0",
632
+ "_view_count": null,
633
+ "_view_module_version": "1.2.0",
634
+ "bar_color": null,
635
+ "_model_module": "@jupyter-widgets/controls"
636
+ }
637
+ },
638
+ "1fb7fd7b44bf4b3a9e949f025487ff47": {
639
+ "model_module": "@jupyter-widgets/base",
640
+ "model_name": "LayoutModel",
641
+ "state": {
642
+ "_view_name": "LayoutView",
643
+ "grid_template_rows": null,
644
+ "right": null,
645
+ "justify_content": null,
646
+ "_view_module": "@jupyter-widgets/base",
647
+ "overflow": null,
648
+ "_model_module_version": "1.2.0",
649
+ "_view_count": null,
650
+ "flex_flow": null,
651
+ "width": null,
652
+ "min_width": null,
653
+ "border": null,
654
+ "align_items": null,
655
+ "bottom": null,
656
+ "_model_module": "@jupyter-widgets/base",
657
+ "top": null,
658
+ "grid_column": null,
659
+ "overflow_y": null,
660
+ "overflow_x": null,
661
+ "grid_auto_flow": null,
662
+ "grid_area": null,
663
+ "grid_template_columns": null,
664
+ "flex": null,
665
+ "_model_name": "LayoutModel",
666
+ "justify_items": null,
667
+ "grid_row": null,
668
+ "max_height": null,
669
+ "align_content": null,
670
+ "visibility": null,
671
+ "align_self": null,
672
+ "height": null,
673
+ "min_height": null,
674
+ "padding": null,
675
+ "grid_auto_rows": null,
676
+ "grid_gap": null,
677
+ "max_width": null,
678
+ "order": null,
679
+ "_view_module_version": "1.2.0",
680
+ "grid_template_areas": null,
681
+ "object_position": null,
682
+ "object_fit": null,
683
+ "grid_auto_columns": null,
684
+ "margin": null,
685
+ "display": null,
686
+ "left": null
687
+ }
688
+ },
689
+ "01c30370e3ff4b5caf9a3369841ad597": {
690
+ "model_module": "@jupyter-widgets/controls",
691
+ "model_name": "DescriptionStyleModel",
692
+ "state": {
693
+ "_view_name": "StyleView",
694
+ "_model_name": "DescriptionStyleModel",
695
+ "description_width": "",
696
+ "_view_module": "@jupyter-widgets/base",
697
+ "_model_module_version": "1.5.0",
698
+ "_view_count": null,
699
+ "_view_module_version": "1.2.0",
700
+ "_model_module": "@jupyter-widgets/controls"
701
+ }
702
+ },
703
+ "d19eedc2acce48d4be7189b422b5fcb9": {
704
+ "model_module": "@jupyter-widgets/base",
705
+ "model_name": "LayoutModel",
706
+ "state": {
707
+ "_view_name": "LayoutView",
708
+ "grid_template_rows": null,
709
+ "right": null,
710
+ "justify_content": null,
711
+ "_view_module": "@jupyter-widgets/base",
712
+ "overflow": null,
713
+ "_model_module_version": "1.2.0",
714
+ "_view_count": null,
715
+ "flex_flow": null,
716
+ "width": null,
717
+ "min_width": null,
718
+ "border": null,
719
+ "align_items": null,
720
+ "bottom": null,
721
+ "_model_module": "@jupyter-widgets/base",
722
+ "top": null,
723
+ "grid_column": null,
724
+ "overflow_y": null,
725
+ "overflow_x": null,
726
+ "grid_auto_flow": null,
727
+ "grid_area": null,
728
+ "grid_template_columns": null,
729
+ "flex": null,
730
+ "_model_name": "LayoutModel",
731
+ "justify_items": null,
732
+ "grid_row": null,
733
+ "max_height": null,
734
+ "align_content": null,
735
+ "visibility": null,
736
+ "align_self": null,
737
+ "height": null,
738
+ "min_height": null,
739
+ "padding": null,
740
+ "grid_auto_rows": null,
741
+ "grid_gap": null,
742
+ "max_width": null,
743
+ "order": null,
744
+ "_view_module_version": "1.2.0",
745
+ "grid_template_areas": null,
746
+ "object_position": null,
747
+ "object_fit": null,
748
+ "grid_auto_columns": null,
749
+ "margin": null,
750
+ "display": null,
751
+ "left": null
752
+ }
753
+ }
754
+ }
755
+ }
756
+ },
757
+ "cells": [
758
+ {
759
+ "cell_type": "code",
760
+ "metadata": {
761
+ "id": "hYCVkKKAwSjV"
762
+ },
763
+ "source": [
764
+ "%%capture\n",
765
+ "!pip install transformers\n",
766
+ "!pip install datasets\n",
767
+ "!pip install --upgrade git+https://github.com/google/flax.git"
768
+ ],
769
+ "execution_count": 1,
770
+ "outputs": []
771
+ },
772
+ {
773
+ "cell_type": "code",
774
+ "metadata": {
775
+ "id": "2gcm5rxByOXO",
776
+ "colab": {
777
+ "base_uri": "https://localhost:8080/",
778
+ "height": 164,
779
+ "referenced_widgets": [
780
+ "1b266a2c1cf646a392a46e39586282b3",
781
+ "8ecfcf14981c4d82b5d9d3839a496f0b",
782
+ "16b07572ac0d46798b2c2a292c3f9143",
783
+ "cf412ff73fc647908154abc9b2847f38",
784
+ "6ebc21286ae843e5b9ba4df8f4cebfe0",
785
+ "48246be80e82429da2d48f9d4a1aaf0a",
786
+ "5754900e885d4f509ede058b186fcab6",
787
+ "d0434381119c46489e17fcbccd9755ea",
788
+ "73c4b8bc05f64477aa03d767f4483795",
789
+ "6123827ad5964b4b8a17aaca618b4768",
790
+ "5327d425e74d4a599214282b9b70d58b",
791
+ "974490d04f18407f9f5a5785b2802c0a",
792
+ "c3cc1723c39a4d74b2ab83bd23b5fcce",
793
+ "391d59bf8d2845f88a83dc25c7cf89f3",
794
+ "d60fa9fe71444784b78bdfba6ed6a9e1",
795
+ "41a3b55e5e264b85ada9558e5777790f",
796
+ "aa4d6e2e9ac44e9bb40b7daccc91ee83",
797
+ "c3b054972a6145d1ad03ca938a7ade9c",
798
+ "644f4a69db534dd4a11172e5d010e8fe",
799
+ "0e19c2d5b060490399efbfcda773e9ba",
800
+ "109479db406d4085acff84904cdac4ef",
801
+ "1fb7fd7b44bf4b3a9e949f025487ff47",
802
+ "01c30370e3ff4b5caf9a3369841ad597",
803
+ "d19eedc2acce48d4be7189b422b5fcb9"
804
+ ]
805
+ },
806
+ "outputId": "5814323f-d04d-408c-e833-8522806ea73b"
807
+ },
808
+ "source": [
809
+ "import jax\n",
810
+ "from transformers.modeling_flax_utils import FlaxPreTrainedModel\n",
811
+ "import flax.linen as nn\n",
812
+ "import jax.numpy as jnp\n",
813
+ "from transformers import GPT2Config\n",
814
+ "from transformers import FlaxGPT2PreTrainedModel\n",
815
+ "from transformers import FlaxGPT2Model\n",
816
+ "from transformers import GPT2Tokenizer\n",
817
+ "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\",pad_token='<|endoftext|>')"
818
+ ],
819
+ "execution_count": 2,
820
+ "outputs": [
821
+ {
822
+ "output_type": "display_data",
823
+ "data": {
824
+ "application/vnd.jupyter.widget-view+json": {
825
+ "model_id": "1b266a2c1cf646a392a46e39586282b3",
826
+ "version_minor": 0,
827
+ "version_major": 2
828
+ },
829
+ "text/plain": [
830
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1042301.0, style=ProgressStyle(descript…"
831
+ ]
832
+ },
833
+ "metadata": {
834
+ "tags": []
835
+ }
836
+ },
837
+ {
838
+ "output_type": "stream",
839
+ "text": [
840
+ "\n"
841
+ ],
842
+ "name": "stdout"
843
+ },
844
+ {
845
+ "output_type": "display_data",
846
+ "data": {
847
+ "application/vnd.jupyter.widget-view+json": {
848
+ "model_id": "73c4b8bc05f64477aa03d767f4483795",
849
+ "version_minor": 0,
850
+ "version_major": 2
851
+ },
852
+ "text/plain": [
853
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=456318.0, style=ProgressStyle(descripti…"
854
+ ]
855
+ },
856
+ "metadata": {
857
+ "tags": []
858
+ }
859
+ },
860
+ {
861
+ "output_type": "stream",
862
+ "text": [
863
+ "\n"
864
+ ],
865
+ "name": "stdout"
866
+ },
867
+ {
868
+ "output_type": "display_data",
869
+ "data": {
870
+ "application/vnd.jupyter.widget-view+json": {
871
+ "model_id": "aa4d6e2e9ac44e9bb40b7daccc91ee83",
872
+ "version_minor": 0,
873
+ "version_major": 2
874
+ },
875
+ "text/plain": [
876
+ "HBox(children=(FloatProgress(value=0.0, description='Downloading', max=1355256.0, style=ProgressStyle(descript…"
877
+ ]
878
+ },
879
+ "metadata": {
880
+ "tags": []
881
+ }
882
+ },
883
+ {
884
+ "output_type": "stream",
885
+ "text": [
886
+ "\n"
887
+ ],
888
+ "name": "stdout"
889
+ }
890
+ ]
891
+ },
892
+ {
893
+ "cell_type": "code",
894
+ "metadata": {
895
+ "id": "GDokS6VEJI6C"
896
+ },
897
+ "source": [
898
+ "#inputs = tokenizer([\"JAX/Flax is amazing \",\"tensorflow is also good\"],[\"pytorch is better\",\"keras is the best\"],return_tensors='jax',padding='max_length',max_length=30)"
899
+ ],
900
+ "execution_count": 3,
901
+ "outputs": []
902
+ },
903
+ {
904
+ "cell_type": "code",
905
+ "metadata": {
906
+ "id": "hWiMk1TzyYim"
907
+ },
908
+ "source": [
909
+ "class FlaxGPT2ForMultipleChoiceModule(nn.Module):\n",
910
+ " config:GPT2Config\n",
911
+ " dtype: jnp.dtype = jnp.float32\n",
912
+ " def setup(self):\n",
913
+ " self.gpt2 = FlaxGPT2Model(config=self.config, dtype=self.dtype)\n",
914
+ " self.dropout = nn.Dropout(rate=0.2)\n",
915
+ " self.classifier = nn.Dense(4, dtype=self.dtype)\n",
916
+ "\n",
917
+ " def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):\n",
918
+ " batch_size = input_ids.shape[0]\n",
919
+ "\n",
920
+ " rng=jax.random.PRNGKey(0)\n",
921
+ " _, dropout_rng = jax.random.split(rng)\n",
922
+ "\n",
923
+ " outputs=self.gpt2(input_ids, attention_mask,position_ids,return_dict=return_dict)\n",
924
+ " \n",
925
+ "\n",
926
+ " hidden_states = outputs[0]\n",
927
+ "\n",
928
+ " \n",
929
+ " hidden_states= jnp.mean(hidden_states, axis=1)\n",
930
+ "\n",
931
+ " print(hidden_states.shape)\n",
932
+ " \n",
933
+ " hidden_states=hidden_states.reshape(batch_size,-1) #(32,8,768)->(32,8*768)\n",
934
+ "\n",
935
+ " dropout_output = self.dropout(hidden_states,deterministic=deterministic,rng=dropout_rng)\n",
936
+ "\n",
937
+ " print(dropout_output.shape)\n",
938
+ "\n",
939
+ " logits = self.classifier(dropout_output)\n",
940
+ " reshaped_logits = logits.reshape(-1, 4) #(32,4)\n",
941
+ " if not return_dict:\n",
942
+ " return (reshaped_logits,) + outputs[2:]\n",
943
+ " return reshaped_logits"
944
+ ],
945
+ "execution_count": 7,
946
+ "outputs": []
947
+ },
948
+ {
949
+ "cell_type": "code",
950
+ "metadata": {
951
+ "id": "u1j00Ck255BC"
952
+ },
953
+ "source": [
954
+ "class FlaxGPT2ForMultipleChoice(FlaxGPT2PreTrainedModel):\n",
955
+ " module_class = FlaxGPT2ForMultipleChoiceModule"
956
+ ],
957
+ "execution_count": 8,
958
+ "outputs": []
959
+ },
960
+ {
961
+ "cell_type": "code",
962
+ "metadata": {
963
+ "id": "h2MrRgKTRxZO",
964
+ "colab": {
965
+ "base_uri": "https://localhost:8080/"
966
+ },
967
+ "outputId": "5a0fcc68-ca39-4df0-c854-734125d65f53"
968
+ },
969
+ "source": [
970
+ "model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2') # getting warning"
971
+ ],
972
+ "execution_count": 9,
973
+ "outputs": [
974
+ {
975
+ "output_type": "stream",
976
+ "text": [
977
+ "(1, 768)\n",
978
+ "(1, 768)\n"
979
+ ],
980
+ "name": "stdout"
981
+ },
982
+ {
983
+ "output_type": "stream",
984
+ "text": [
985
+ "Some weights of the model checkpoint at gpt2 were not used when initializing FlaxGPT2ForMultipleChoice: {('h', '1', 'ln_1', 'bias'), ('h', '6', 'ln_1', 'scale'), ('h', '1', 'attn', 'c_proj', 'kernel'), ('h', '11', 'mlp', 'c_fc', 'bias'), ('h', '7', 'ln_1', 'bias'), ('h', '5', 'ln_2', 'bias'), ('h', '10', 'ln_2', 'scale'), ('h', '4', 'mlp', 'c_proj', 'kernel'), ('h', '0', 'mlp', 'c_proj', 'bias'), ('h', '0', 'ln_1', 'bias'), ('h', '0', 'mlp', 'c_fc', 'kernel'), ('wpe', 'embedding'), ('h', '3', 'ln_1', 'scale'), ('h', '2', 'ln_1', 'scale'), ('h', '3', 'mlp', 'c_fc', 'kernel'), ('h', '7', 'ln_1', 'scale'), ('h', '8', 'mlp', 'c_proj', 'kernel'), ('h', '7', 'mlp', 'c_proj', 'kernel'), ('h', '3', 'ln_2', 'bias'), ('h', '9', 'attn', 'c_attn', 'kernel'), ('h', '0', 'mlp', 'c_fc', 'bias'), ('h', '3', 'attn', 'c_proj', 'bias'), ('h', '0', 'ln_1', 'scale'), ('h', '3', 'attn', 'c_attn', 'kernel'), ('h', '0', 'mlp', 'c_proj', 'kernel'), ('h', '5', 'ln_1', 'bias'), ('h', '7', 'attn', 'c_attn', 'bias'), ('h', '1', 'ln_2', 'bias'), ('h', '11', 'ln_2', 'scale'), ('h', '7', 'ln_2', 'bias'), ('h', '9', 'attn', 'c_proj', 'kernel'), ('h', '0', 'ln_2', 'bias'), ('h', '2', 'ln_2', 'scale'), ('h', '11', 'attn', 'c_attn', 'kernel'), ('h', '8', 'attn', 'c_proj', 'kernel'), ('h', '4', 'attn', 'c_attn', 'kernel'), ('h', '5', 'ln_1', 'scale'), ('h', '4', 'ln_1', 'bias'), ('h', '8', 'ln_2', 'bias'), ('h', '1', 'mlp', 'c_fc', 'kernel'), ('h', '9', 'ln_2', 'scale'), ('h', '1', 'mlp', 'c_proj', 'bias'), ('h', '2', 'mlp', 'c_proj', 'kernel'), ('h', '9', 'attn', 'c_proj', 'bias'), ('h', '11', 'ln_2', 'bias'), ('h', '6', 'mlp', 'c_proj', 'bias'), ('h', '3', 'ln_1', 'bias'), ('h', '1', 'attn', 'c_attn', 'kernel'), ('h', '9', 'ln_1', 'scale'), ('h', '10', 'attn', 'c_attn', 'bias'), ('h', '10', 'mlp', 'c_proj', 'kernel'), ('h', '2', 'attn', 'c_proj', 'kernel'), ('h', '0', 'attn', 'c_proj', 'kernel'), ('h', '6', 'attn', 'c_attn', 'kernel'), ('h', '4', 'mlp', 'c_fc', 'bias'), ('h', '3', 'attn', 'c_attn', 'bias'), ('h', '3', 'attn', 'c_proj', 'kernel'), ('h', '11', 'mlp', 'c_proj', 'bias'), ('h', '9', 'attn', 'c_attn', 'bias'), ('h', '7', 'mlp', 'c_proj', 'bias'), ('h', '7', 'mlp', 'c_fc', 'bias'), ('h', '6', 'attn', 'c_attn', 'bias'), ('h', '5', 'mlp', 'c_fc', 'kernel'), ('h', '0', 'attn', 'c_proj', 'bias'), ('h', '2', 'attn', 'c_proj', 'bias'), ('h', '10', 'attn', 'c_attn', 'kernel'), ('h', '10', 'mlp', 'c_proj', 'bias'), ('h', '1', 'attn', 'c_attn', 'bias'), ('h', '11', 'ln_1', 'bias'), ('h', '4', 'ln_2', 'bias'), ('h', '8', 'ln_1', 'bias'), ('h', '11', 'attn', 'c_proj', 'kernel'), ('h', '9', 'mlp', 'c_fc', 'kernel'), ('h', '7', 'ln_2', 'scale'), ('h', '9', 'mlp', 'c_proj', 'kernel'), ('h', '11', 'attn', 'c_attn', 'bias'), ('h', '10', 'mlp', 'c_fc', 'bias'), ('h', '6', 'attn', 'c_proj', 'kernel'), ('h', '0', 'ln_2', 'scale'), ('h', '2', 'ln_2', 'bias'), ('h', '3', 'mlp', 'c_proj', 'bias'), ('h', '5', 'mlp', 'c_proj', 'kernel'), ('h', '8', 'mlp', 'c_fc', 'bias'), ('h', '9', 'mlp', 'c_proj', 'bias'), ('h', '9', 'mlp', 'c_fc', 'bias'), ('h', '8', 'mlp', 'c_fc', 'kernel'), ('h', '9', 'ln_1', 'bias'), ('h', '10', 'ln_1', 'scale'), ('h', '6', 'ln_2', 'bias'), ('h', '2', 'mlp', 'c_fc', 'kernel'), ('h', '4', 'attn', 'c_proj', 'bias'), ('h', '1', 'ln_2', 'scale'), ('h', '5', 'mlp', 'c_fc', 'bias'), ('h', '7', 'mlp', 'c_fc', 'kernel'), ('h', '7', 'attn', 'c_proj', 'bias'), ('h', '5', 'attn', 'c_proj', 'kernel'), ('h', '2', 'mlp', 'c_fc', 'bias'), ('h', '6', 'ln_2', 'scale'), ('h', '11', 'ln_1', 'scale'), ('h', '4', 'mlp', 'c_fc', 
'kernel'), ('h', '2', 'ln_1', 'bias'), ('h', '9', 'ln_2', 'bias'), ('h', '11', 'mlp', 'c_fc', 'kernel'), ('h', '1', 'attn', 'c_proj', 'bias'), ('h', '4', 'ln_2', 'scale'), ('h', '8', 'ln_1', 'scale'), ('h', '6', 'attn', 'c_proj', 'bias'), ('h', '5', 'attn', 'c_attn', 'kernel'), ('h', '3', 'ln_2', 'scale'), ('h', '8', 'attn', 'c_attn', 'bias'), ('h', '10', 'mlp', 'c_fc', 'kernel'), ('h', '1', 'ln_1', 'scale'), ('h', '10', 'attn', 'c_proj', 'bias'), ('h', '6', 'ln_1', 'bias'), ('h', '0', 'attn', 'c_attn', 'kernel'), ('wte', 'embedding'), ('h', '6', 'mlp', 'c_fc', 'kernel'), ('h', '4', 'attn', 'c_attn', 'bias'), ('h', '10', 'ln_2', 'bias'), ('h', '8', 'attn', 'c_proj', 'bias'), ('h', '11', 'attn', 'c_proj', 'bias'), ('h', '8', 'attn', 'c_attn', 'kernel'), ('h', '5', 'attn', 'c_attn', 'bias'), ('h', '5', 'ln_2', 'scale'), ('h', '2', 'attn', 'c_attn', 'bias'), ('ln_f', 'scale'), ('h', '7', 'attn', 'c_attn', 'kernel'), ('h', '4', 'ln_1', 'scale'), ('h', '8', 'ln_2', 'scale'), ('h', '11', 'mlp', 'c_proj', 'kernel'), ('h', '5', 'attn', 'c_proj', 'bias'), ('h', '7', 'attn', 'c_proj', 'kernel'), ('h', '8', 'mlp', 'c_proj', 'bias'), ('h', '3', 'mlp', 'c_fc', 'bias'), ('h', '10', 'ln_1', 'bias'), ('h', '2', 'attn', 'c_attn', 'kernel'), ('h', '6', 'mlp', 'c_proj', 'kernel'), ('h', '4', 'attn', 'c_proj', 'kernel'), ('h', '1', 'mlp', 'c_proj', 'kernel'), ('h', '2', 'mlp', 'c_proj', 'bias'), ('h', '1', 'mlp', 'c_fc', 'bias'), ('h', '4', 'mlp', 'c_proj', 'bias'), ('ln_f', 'bias'), ('h', '6', 'mlp', 'c_fc', 'bias'), ('h', '0', 'attn', 'c_attn', 'bias'), ('h', '10', 'attn', 'c_proj', 'kernel'), ('h', '5', 'mlp', 'c_proj', 'bias'), ('h', '3', 'mlp', 'c_proj', 'kernel')}\n",
986
+ "- This IS expected if you are initializing FlaxGPT2ForMultipleChoice from the checkpoint of a model trained on another task or with another architecture (e.g. initializing a BertForSequenceClassification model from a BertForPreTraining model).\n",
987
+ "- This IS NOT expected if you are initializing FlaxGPT2ForMultipleChoice from the checkpoint of a model that you expect to be exactly identical (initializing a BertForSequenceClassification model from a BertForSequenceClassification model).\n",
988
+ "Some weights of FlaxGPT2ForMultipleChoice were not initialized from the model checkpoint at gpt2 and are newly initialized: {('classifier', 'bias'), ('classifier', 'kernel')}\n",
989
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
990
+ ],
991
+ "name": "stderr"
992
+ }
993
+ ]
994
+ },
995
+ {
996
+ "cell_type": "code",
997
+ "metadata": {
998
+ "id": "CdSuQK9pRmw-"
999
+ },
1000
+ "source": [
1001
+ "input_ids=jnp.ones((4,5,6))\n",
1002
+ "attention_mask=jnp.ones((4,5,6))"
1003
+ ],
1004
+ "execution_count": 10,
1005
+ "outputs": []
1006
+ },
1007
+ {
1008
+ "cell_type": "code",
1009
+ "metadata": {
1010
+ "id": "d3Bu38KTkwWs",
1011
+ "colab": {
1012
+ "base_uri": "https://localhost:8080/",
1013
+ "height": 300
1014
+ },
1015
+ "outputId": "5470ccc9-6d49-427c-ad8e-5162343acfde"
1016
+ },
1017
+ "source": [
1018
+ "out1 = model(input_ids, attention_mask) #GPT2 will not take (batch_size,num_choice,sequence_length)"
1019
+ ],
1020
+ "execution_count": 11,
1021
+ "outputs": [
1022
+ {
1023
+ "output_type": "error",
1024
+ "ename": "ValueError",
1025
+ "evalue": "ignored",
1026
+ "traceback": [
1027
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1028
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
1029
+ "\u001b[0;32m<ipython-input-11-7491141e6756>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mout1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;31m#GPT2 will not take (batch_size,num_choice,sequence_length)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
1030
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 370\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mreturn_dict\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mNone\u001b[0m \u001b[0;32melse\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mconfig\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 371\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 372\u001b[0;31m \u001b[0mbatch_size\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0msequence_length\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0minput_ids\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mshape\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 373\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 374\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mposition_ids\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mNone\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
1031
+ "\u001b[0;31mValueError\u001b[0m: too many values to unpack (expected 2)"
1032
+ ]
1033
+ }
1034
+ ]
1035
+ },
1036
+ {
1037
+ "cell_type": "code",
1038
+ "metadata": {
1039
+ "id": "VZPlQfkhgLJd",
1040
+ "colab": {
1041
+ "base_uri": "https://localhost:8080/"
1042
+ },
1043
+ "outputId": "7948688b-be77-4e9a-fc06-8644a2614d42"
1044
+ },
1045
+ "source": [
1046
+ "print(out1)"
1047
+ ],
1048
+ "execution_count": null,
1049
+ "outputs": [
1050
+ {
1051
+ "output_type": "stream",
1052
+ "text": [
1053
+ "[[ 1.1391759 -0.01598702 0.55463445 0.36025363]\n",
1054
+ " [ 0.32208228 0.37667227 0.87823874 0.19541818]\n",
1055
+ " [ 0.76971424 0.7187787 0.68642044 -0.31461257]\n",
1056
+ " [ 1.2375658 0.03325981 0.00153449 0.12019679]]\n"
1057
+ ],
1058
+ "name": "stdout"
1059
+ }
1060
+ ]
1061
+ },
1062
+ {
1063
+ "cell_type": "code",
1064
+ "metadata": {
1065
+ "id": "fgkIcD-mZWP7"
1066
+ },
1067
+ "source": [
1068
+ ""
1069
+ ],
1070
+ "execution_count": null,
1071
+ "outputs": []
1072
+ }
1073
+ ]
1074
+ }
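Note on the error in GPT2(error).ipynb: FlaxGPT2Model's __call__ unpacks batch_size, sequence_length = input_ids.shape, so it only accepts 2-D input_ids, while the notebook feeds a 3-D (batch_size, num_choices, sequence_length) batch, hence the ValueError above. A common multiple-choice pattern (an assumption here, not what this commit implements) is to fold the choice dimension into the batch dimension before calling the backbone and reshape the per-choice logits afterwards. A minimal sketch with hypothetical helper names:

# Sketch only, not part of the committed notebooks; flatten_choices is a hypothetical helper.
import jax.numpy as jnp

def flatten_choices(input_ids, attention_mask):
    # (batch, num_choices, seq_len) -> (batch * num_choices, seq_len)
    batch_size, num_choices, seq_len = input_ids.shape
    return (input_ids.reshape(-1, seq_len),
            attention_mask.reshape(-1, seq_len),
            (batch_size, num_choices))

# With the shapes from the failing cell, jnp.ones((4, 5, 6)) becomes (20, 6),
# which the 2-D GPT-2 backbone accepts; per-choice logits can later be restored
# with logits.reshape(batch_size, num_choices).
input_ids = jnp.ones((4, 5, 6), dtype="i4")
attention_mask = jnp.ones((4, 5, 6), dtype="i4")
flat_ids, flat_mask, (bs, nc) = flatten_choices(input_ids, attention_mask)
assert flat_ids.shape == (20, 6)

This roughly mirrors how the existing Bert-style multiple-choice heads in transformers handle the extra choice axis before calling a backbone that expects 2-D inputs.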
Untitled330.ipynb ADDED
@@ -0,0 +1,470 @@
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "name": "Untitled330.ipynb",
7
+ "provenance": [],
8
+ "collapsed_sections": []
9
+ },
10
+ "kernelspec": {
11
+ "name": "python3",
12
+ "display_name": "Python 3"
13
+ },
14
+ "language_info": {
15
+ "name": "python"
16
+ }
17
+ },
18
+ "cells": [
19
+ {
20
+ "cell_type": "code",
21
+ "metadata": {
22
+ "id": "Ii2x731Ta8fu"
23
+ },
24
+ "source": [
25
+ "%%capture\n",
26
+ "!pip install transformers\n",
27
+ "!pip install datasets\n",
28
+ "!pip install --upgrade git+https://github.com/google/flax.git"
29
+ ],
30
+ "execution_count": 1,
31
+ "outputs": []
32
+ },
33
+ {
34
+ "cell_type": "code",
35
+ "metadata": {
36
+ "id": "_9NMPFKua9hr"
37
+ },
38
+ "source": [
39
+ "import jax\n",
40
+ "from transformers.modeling_flax_utils import FlaxPreTrainedModel\n",
41
+ "import flax.linen as nn\n",
42
+ "import jax.numpy as jnp\n",
43
+ "from transformers import GPT2Config\n",
44
+ "#from transformers import FlaxGPT2PreTrainedModel\n",
45
+ "from transformers import FlaxGPT2Model\n",
46
+ "import jax.numpy as jnp\n",
47
+ "from transformers import GPT2Tokenizer\n",
48
+ "tokenizer = GPT2Tokenizer.from_pretrained(\"gpt2\",pad_token='<|endoftext|>') \n",
49
+ "from typing import Any, Optional, Tuple\n",
50
+ "from flax.core.frozen_dict import FrozenDict, unfreeze\n",
51
+ "from transformers import file_utils\n",
52
+ "from transformers.file_utils import add_start_docstrings, add_start_docstrings_to_model_forward\n",
53
+ "from transformers.models.gpt2.modeling_flax_gpt2 import FlaxGPT2BlockCollection\n",
54
+ "from transformers.modeling_flax_outputs import FlaxBaseModelOutput"
55
+ ],
56
+ "execution_count": null,
57
+ "outputs": []
58
+ },
59
+ {
60
+ "cell_type": "code",
61
+ "metadata": {
62
+ "id": "dqkcoBOccszd"
63
+ },
64
+ "source": [
65
+ "GPT2_START_DOCSTRING = r\"\"\"\n",
66
+ " This model inherits from :class:`~transformers.FlaxPreTrainedModel`. Check the superclass documentation for the\n",
67
+ " generic methods the library implements for all its model (such as downloading or saving, resizing the input\n",
68
+ " embeddings, pruning heads etc.)\n",
69
+ " This model is also a Flax Linen `flax.nn.Module\n",
70
+ " <https://flax.readthedocs.io/en/latest/_autosummary/flax.nn.module.html>`__ subclass. Use it as a regular Flax\n",
71
+ " Module and refer to the Flax documentation for all matter related to general usage and behavior.\n",
72
+ " Finally, this model supports inherent JAX features such as:\n",
73
+ " - `Just-In-Time (JIT) compilation <https://jax.readthedocs.io/en/latest/jax.html#just-in-time-compilation-jit>`__\n",
74
+ " - `Automatic Differentiation <https://jax.readthedocs.io/en/latest/jax.html#automatic-differentiation>`__\n",
75
+ " - `Vectorization <https://jax.readthedocs.io/en/latest/jax.html#vectorization-vmap>`__\n",
76
+ " - `Parallelization <https://jax.readthedocs.io/en/latest/jax.html#parallelization-pmap>`__\n",
77
+ " Parameters:\n",
78
+ " config (:class:`~transformers.GPT2Config`): Model configuration class with all the parameters of the model.\n",
79
+ " Initializing with a config file does not load the weights associated with the model, only the\n",
80
+ " configuration. Check out the :meth:`~transformers.FlaxPreTrainedModel.from_pretrained` method to load the\n",
81
+ " model weights.\n",
82
+ "\"\"\"\n",
83
+ "\n",
84
+ "GPT2_INPUTS_DOCSTRING = r\"\"\"\n",
85
+ " Args:\n",
86
+ " input_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size,input_ids_length)`):\n",
87
+ " :obj:`input_ids_length` = ``sequence_length``. Indices of input sequence tokens in the vocabulary.\n",
88
+ " Indices can be obtained using :class:`~transformers.GPT2Tokenizer`. See\n",
89
+ " :meth:`transformers.PreTrainedTokenizer.encode` and :meth:`transformers.PreTrainedTokenizer.__call__` for\n",
90
+ " details.\n",
91
+ " `What are input IDs? <../glossary.html#input-ids>`__\n",
92
+ " attention_mask (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
93
+ " Mask to avoid performing attention on padding token indices. Mask values selected in ``[0, 1]``:\n",
94
+ " - 1 for tokens that are **not masked**,\n",
95
+ " - 0 for tokens that are **masked**.\n",
96
+ " `What are attention masks? <../glossary.html#attention-mask>`__\n",
97
+ " position_ids (:obj:`numpy.ndarray` of shape :obj:`(batch_size, sequence_length)`, `optional`):\n",
98
+ " Indices of positions of each input sequence tokens in the position embeddings. Selected in the range ``[0,\n",
99
+ " config.max_position_embeddings - 1]``.\n",
100
+ " past_key_values (:obj:`Dict[str, np.ndarray]`, `optional`, returned by ``init_cache`` or when passing previous ``past_key_values``):\n",
101
+ " Dictionary of pre-computed hidden-states (key and values in the attention blocks) that can be used for fast\n",
102
+ " auto-regressive decoding. Pre-computed key and value hidden-states are of shape `[batch_size, max_length]`.\n",
103
+ " output_attentions (:obj:`bool`, `optional`):\n",
104
+ " Whether or not to return the attentions tensors of all attention layers. See ``attentions`` under returned\n",
105
+ " tensors for more detail.\n",
106
+ " output_hidden_states (:obj:`bool`, `optional`):\n",
107
+ " Whether or not to return the hidden states of all layers. See ``hidden_states`` under returned tensors for\n",
108
+ " more detail.\n",
109
+ " return_dict (:obj:`bool`, `optional`):\n",
110
+ " Whether or not to return a :class:`~transformers.file_utils.ModelOutput` instead of a plain tuple.\n",
111
+ "\"\"\""
112
+ ],
113
+ "execution_count": 3,
114
+ "outputs": []
115
+ },
116
+ {
117
+ "cell_type": "code",
118
+ "metadata": {
119
+ "id": "NX-Z5iCMbKL5"
120
+ },
121
+ "source": [
122
+ "class FlaxGGGPreTrainedModel(FlaxPreTrainedModel):\n",
123
+ " \"\"\"\n",
124
+ " An abstract class to handle weights initialization and a simple interface for downloading and loading pretrained\n",
125
+ " models.\n",
126
+ " \"\"\"\n",
127
+ "\n",
128
+ " config_class = GPT2Config\n",
129
+ " base_model_prefix = \"transformer\"\n",
130
+ " module_class: nn.Module = None\n",
131
+ "\n",
132
+ " def __init__(\n",
133
+ " self,\n",
134
+ " config: GPT2Config,\n",
135
+ " input_shape: Tuple = (1,1),\n",
136
+ " seed: int = 0,\n",
137
+ " dtype: jnp.dtype = jnp.float32,\n",
138
+ " **kwargs,\n",
139
+ " ):\n",
140
+ " \n",
141
+ " module = self.module_class(config=config, dtype=dtype, **kwargs)\n",
142
+ " super().__init__(config, module, input_shape=input_shape, seed=seed, dtype=dtype)\n",
143
+ "\n",
144
+ " def init_weights(self, rng: jax.random.PRNGKey, input_shape: Tuple) -> FrozenDict:\n",
145
+ " # init input tensors\n",
146
+ " input_ids = jnp.zeros(input_shape, dtype=\"i4\")\n",
147
+ " attention_mask = jnp.ones_like(input_ids)\n",
148
+ " \n",
149
+ " position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_shape)\n",
150
+ "\n",
151
+ " params_rng, dropout_rng = jax.random.split(rng)\n",
152
+ " rngs = {\"params\": params_rng, \"dropout\": dropout_rng}\n",
153
+ "\n",
154
+ " return self.module.init(rngs, input_ids, attention_mask, position_ids, return_dict=False)[\"params\"]\n",
155
+ "\n",
156
+ " def init_cache(self, batch_size, max_length):\n",
157
+ " r\"\"\"\n",
158
+ " Args:\n",
159
+ " batch_size (:obj:`int`):\n",
160
+ " batch_size used for fast auto-regressive decoding. Defines the batch size of the initialized cache.\n",
161
+ " max_length (:obj:`int`):\n",
162
+ " maximum possible length for auto-regressive decoding. Defines the sequence length of the initialized\n",
163
+ " cache.\n",
164
+ " \"\"\"\n",
165
+ " # init input variables to retrieve cache\n",
166
+ " input_ids = jnp.ones((batch_size, max_length))\n",
167
+ " attention_mask = jnp.ones_like(input_ids)\n",
168
+ " position_ids = jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)\n",
169
+ "\n",
170
+ " init_variables = self.module.init(\n",
171
+ " jax.random.PRNGKey(0), input_ids, attention_mask, position_ids, return_dict=False, init_cache=True\n",
172
+ " )\n",
173
+ " return init_variables[\"cache\"]\n",
174
+ "\n",
175
+ " @add_start_docstrings_to_model_forward(GPT2_INPUTS_DOCSTRING)\n",
176
+ " def __call__(\n",
177
+ " self,\n",
178
+ " input_ids,\n",
179
+ " attention_mask=None,\n",
180
+ " position_ids=None,\n",
181
+ " params: dict = None,\n",
182
+ " past_key_values: dict = None,\n",
183
+ " dropout_rng: jax.random.PRNGKey = None,\n",
184
+ " train: bool = False,\n",
185
+ " output_attentions: Optional[bool] = None,\n",
186
+ " output_hidden_states: Optional[bool] = None,\n",
187
+ " return_dict: Optional[bool] = None,\n",
188
+ " ):\n",
189
+ " output_attentions = output_attentions if output_attentions is not None else self.config.output_attentions\n",
190
+ " output_hidden_states = (\n",
191
+ " output_hidden_states if output_hidden_states is not None else self.config.output_hidden_states\n",
192
+ " )\n",
193
+ " return_dict = return_dict if return_dict is not None else self.config.return_dict\n",
194
+ " print(input_ids.shape)\n",
195
+ "\n",
196
+ " # batch_size, num_choices,sequence_length = input_ids.shape\n",
197
+ "\n",
198
+ " if position_ids is None:\n",
199
+ " if past_key_values is not None:\n",
200
+ " raise ValueError(\"Make sure to provide `position_ids` when passing `past_key_values`.\")\n",
201
+ " \n",
202
+ " position_ids=jnp.broadcast_to(jnp.arange(jnp.atleast_2d(input_ids).shape[-1]), input_ids.shape)\n",
203
+ "\n",
204
+ " # position_ids = jnp.broadcast_to(jnp.arange(sequence_length)[None, :], (batch_size, sequence_length))\n",
205
+ "\n",
206
+ " if attention_mask is None:\n",
207
+ " attention_mask = jnp.ones((input_ids))\n",
208
+ " print('attn not')\n",
209
+ "\n",
210
+ " # Handle any PRNG if needed\n",
211
+ " rngs = {}\n",
212
+ " if dropout_rng is not None:\n",
213
+ " rngs[\"dropout\"] = dropout_rng\n",
214
+ "\n",
215
+ " inputs = {\"params\": params or self.params}\n",
216
+ "\n",
217
+ " # if past_key_values are passed then cache is already initialized a private flag init_cache has to be passed down to ensure cache is used. It has to be made sure that cache is marked as mutable so that it can be changed by FlaxGPT2Attention module\n",
218
+ " if past_key_values:\n",
219
+ " inputs[\"cache\"] = past_key_values\n",
220
+ " mutable = [\"cache\"]\n",
221
+ " else:\n",
222
+ " mutable = False\n",
223
+ "\n",
224
+ " outputs = self.module.apply(\n",
225
+ " inputs,\n",
226
+ " jnp.array(input_ids, dtype=\"i4\"),\n",
227
+ " jnp.array(attention_mask, dtype=\"i4\"),\n",
228
+ " jnp.array(position_ids, dtype=\"i4\"),\n",
229
+ " not train,\n",
230
+ " False,\n",
231
+ " output_attentions,\n",
232
+ " output_hidden_states,\n",
233
+ " return_dict,\n",
234
+ " rngs=rngs,\n",
235
+ " mutable=mutable,\n",
236
+ " )\n",
237
+ " print('cache')\n",
238
+ "\n",
239
+ " # add updated cache to model output\n",
240
+ " if past_key_values is not None and return_dict:\n",
241
+ " outputs, past_key_values = outputs\n",
242
+ " outputs[\"past_key_values\"] = unfreeze(past_key_values[\"cache\"])\n",
243
+ " return outputs\n",
244
+ " elif past_key_values is not None and not return_dict:\n",
245
+ " outputs, past_key_values = outputs\n",
246
+ " outputs = outputs[:1] + (unfreeze(past_key_values[\"cache\"]),) + outputs[1:]\n",
247
+ "\n",
248
+ " return outputs"
249
+ ],
250
+ "execution_count": 6,
251
+ "outputs": []
252
+ },
253
+ {
254
+ "cell_type": "code",
255
+ "metadata": {
256
+ "id": "4vRAWll2bwQQ"
257
+ },
258
+ "source": [
259
+ "class FlaxGGGModule(nn.Module):\n",
260
+ " config: GPT2Config\n",
261
+ " dtype: jnp.dtype = jnp.float32\n",
262
+ "\n",
263
+ " def setup(self):\n",
264
+ " self.embed_dim = self.config.hidden_size\n",
265
+ "\n",
266
+ " self.wte = nn.Embed(\n",
267
+ " self.config.vocab_size,\n",
268
+ " self.embed_dim,\n",
269
+ " embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n",
270
+ " dtype=self.dtype,\n",
271
+ " )\n",
272
+ " self.wpe = nn.Embed(\n",
273
+ " self.config.max_position_embeddings,\n",
274
+ " self.embed_dim,\n",
275
+ " embedding_init=jax.nn.initializers.normal(stddev=self.config.initializer_range),\n",
276
+ " dtype=self.dtype,\n",
277
+ " )\n",
278
+ " self.dropout = nn.Dropout(rate=self.config.embd_pdrop)\n",
279
+ " self.h = FlaxGPT2BlockCollection(self.config, dtype=self.dtype)\n",
280
+ " self.ln_f = nn.LayerNorm(epsilon=self.config.layer_norm_epsilon, dtype=self.dtype)\n",
281
+ "\n",
282
+ " def __call__(\n",
283
+ " self,\n",
284
+ " input_ids,\n",
285
+ " attention_mask,\n",
286
+ " position_ids,\n",
287
+ " deterministic=True,\n",
288
+ " init_cache: bool = False,\n",
289
+ " output_attentions: bool = False,\n",
290
+ " output_hidden_states: bool = False,\n",
291
+ " return_dict: bool = True,\n",
292
+ " ):\n",
293
+ " input_embeds = self.wte(input_ids.astype(\"i4\"))\n",
294
+ " position_embeds = self.wpe(position_ids.astype(\"i4\"))\n",
295
+ " \n",
296
+ "\n",
297
+ " hidden_states = input_embeds + position_embeds\n",
298
+ " hidden_states = self.dropout(hidden_states, deterministic=deterministic)\n",
299
+ " outputs = self.h(\n",
300
+ " hidden_states,\n",
301
+ " attention_mask,\n",
302
+ " deterministic=deterministic,\n",
303
+ " init_cache=init_cache,\n",
304
+ " output_attentions=output_attentions,\n",
305
+ " output_hidden_states=output_hidden_states,\n",
306
+ " return_dict=return_dict,\n",
307
+ " )\n",
308
+ "\n",
309
+ " hidden_states = outputs[0]\n",
310
+ " hidden_states = self.ln_f(hidden_states)\n",
311
+ " print('ggg')\n",
312
+ " if not return_dict:\n",
313
+ " return (hidden_states,) + outputs[1:]\n",
314
+ "\n",
315
+ " return FlaxBaseModelOutput(\n",
316
+ " last_hidden_state=hidden_states,\n",
317
+ " hidden_states=outputs.hidden_states,\n",
318
+ " attentions=outputs.attentions,)\n",
319
+ "class FlaxNewModel(FlaxGGGPreTrainedModel):\n",
320
+ " module_class = FlaxGGGModule"
321
+ ],
322
+ "execution_count": 7,
323
+ "outputs": []
324
+ },
325
+ {
326
+ "cell_type": "code",
327
+ "metadata": {
328
+ "id": "_ljSn6GdedtI"
329
+ },
330
+ "source": [
331
+ "class FlaxGPT2ForMultipleChoiceModule(nn.Module):\n",
332
+ " config:GPT2Config\n",
333
+ " dtype: jnp.dtype = jnp.float32\n",
334
+ " def setup(self):\n",
335
+ " self.gpt2 = FlaxNewModel(config=self.config, dtype=self.dtype)\n",
336
+ " self.dropout = nn.Dropout(rate=0.2)\n",
337
+ " self.classifier = nn.Dense(4, dtype=self.dtype)\n",
338
+ "\n",
339
+ " def __call__(self,input_ids,attention_mask,position_ids,return_dict=True,deterministic=True,*args):\n",
340
+ " batch_size = input_ids.shape[0]\n",
341
+ " rng=jax.random.PRNGKey(0)\n",
342
+ " _, dropout_rng = jax.random.split(rng)\n",
343
+ " print('abc')\n",
344
+ "\n",
345
+ " outputs=self.gpt2(input_ids, attention_mask,position_ids,return_dict=return_dict)\n",
346
+ " \n",
347
+ "\n",
348
+ " hidden_states = outputs[0]\n",
349
+ "\n",
350
+ " \n",
351
+ " hidden_states= jnp.mean(hidden_states, axis=1)\n",
352
+ "\n",
353
+ " print(hidden_states.shape)\n",
354
+ " \n",
355
+ " \n",
356
+ " hidden_states=hidden_states.reshape(batch_size,-1) #(32,8,768)->(32,8*768)\n",
357
+ "\n",
358
+ " dropout_output = self.dropout(hidden_states,deterministic=deterministic,rng=dropout_rng)\n",
359
+ "\n",
360
+ " print(dropout_output.shape)\n",
361
+ " \n",
362
+ "\n",
363
+ " logits = self.classifier(dropout_output)\n",
364
+ " print('bnv')\n",
365
+ " reshaped_logits = logits.reshape(-1, 4) \n",
366
+ " #(32,4)\n",
367
+ " if not return_dict:\n",
368
+ " return (reshaped_logits,) + outputs[2:]\n",
369
+ " return reshaped_logits"
370
+ ],
371
+ "execution_count": 8,
372
+ "outputs": []
373
+ },
374
+ {
375
+ "cell_type": "code",
376
+ "metadata": {
377
+ "id": "M4UPf3Waexq0"
378
+ },
379
+ "source": [
380
+ "class FlaxGPT2ForMultipleChoice(FlaxNewModel):\n",
381
+ " module_class = FlaxGPT2ForMultipleChoiceModule"
382
+ ],
383
+ "execution_count": 9,
384
+ "outputs": []
385
+ },
386
+ {
387
+ "cell_type": "code",
388
+ "metadata": {
389
+ "id": "roQ3vls4e4TH"
390
+ },
391
+ "source": [
392
+ "model = FlaxGPT2ForMultipleChoice.from_pretrained('gpt2')"
393
+ ],
394
+ "execution_count": null,
395
+ "outputs": []
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "metadata": {
400
+ "id": "E9qOSaaie417"
401
+ },
402
+ "source": [
403
+ "input_ids=jnp.ones((1,2,11))\n",
404
+ "attention_mask=jnp.ones((1,2,11))"
405
+ ],
406
+ "execution_count": 12,
407
+ "outputs": []
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "metadata": {
412
+ "colab": {
413
+ "base_uri": "https://localhost:8080/",
414
+ "height": 409
415
+ },
416
+ "id": "am7hYv8auWVy",
417
+ "outputId": "0c8192ca-a0ab-432e-d483-46f8a2cc2576"
418
+ },
419
+ "source": [
420
+ "out1 = model(input_ids, attention_mask)"
421
+ ],
422
+ "execution_count": 13,
423
+ "outputs": [
424
+ {
425
+ "output_type": "stream",
426
+ "text": [
427
+ "(1, 2, 11)\n",
428
+ "attn not\n",
429
+ "ggg\n",
430
+ "abc\n",
431
+ "(1, 2, 11)\n",
432
+ "attn not\n"
433
+ ],
434
+ "name": "stdout"
435
+ },
436
+ {
437
+ "output_type": "error",
438
+ "ename": "ValueError",
439
+ "evalue": "ignored",
440
+ "traceback": [
441
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
442
+ "\u001b[0;31mValueError\u001b[0m Traceback (most recent call last)",
443
+ "\u001b[0;32m<ipython-input-13-6be36035677e>\u001b[0m in \u001b[0;36m<module>\u001b[0;34m()\u001b[0m\n\u001b[0;32m----> 1\u001b[0;31m \u001b[0mout1\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mmodel\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m",
444
+ "\u001b[0;32m<ipython-input-6-de553f26d169>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m )\n\u001b[1;32m 116\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cache'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
445
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 966\u001b[0;31m )(variables, *args, **kwargs, rngs=rngs)\n\u001b[0m\u001b[1;32m 967\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 968\u001b[0m def init_with_output(self,\n",
446
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/scope.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(variables, rngs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 685\u001b[0m **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:\n\u001b[1;32m 686\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvariables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtemporary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 687\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 688\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmutable\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 689\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmutable_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
447
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mscope_fn\u001b[0;34m(scope, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1214\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1216\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1217\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
448
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
449
+ "\u001b[0;32m<ipython-input-8-2c21e4c966c8>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, return_dict, deterministic, *args)\u001b[0m\n\u001b[1;32m 13\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'abc'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 15\u001b[0;31m \u001b[0moutputs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mgpt2\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0minput_ids\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mattention_mask\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mposition_ids\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 16\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
450
+ "\u001b[0;32m<ipython-input-6-de553f26d169>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, params, past_key_values, dropout_rng, train, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 112\u001b[0m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 113\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 114\u001b[0;31m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 115\u001b[0m )\n\u001b[1;32m 116\u001b[0m \u001b[0mprint\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m'cache'\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
451
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mapply\u001b[0;34m(self, variables, rngs, method, mutable, capture_intermediates, *args, **kwargs)\u001b[0m\n\u001b[1;32m 964\u001b[0m \u001b[0mmethod\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 965\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 966\u001b[0;31m )(variables, *args, **kwargs, rngs=rngs)\n\u001b[0m\u001b[1;32m 967\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 968\u001b[0m def init_with_output(self,\n",
452
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/core/scope.py\u001b[0m in \u001b[0;36mwrapper\u001b[0;34m(variables, rngs, *args, **kwargs)\u001b[0m\n\u001b[1;32m 685\u001b[0m **kwargs) -> Union[Any, Tuple[Any, VariableDict]]:\n\u001b[1;32m 686\u001b[0m \u001b[0;32mwith\u001b[0m \u001b[0mbind\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mvariables\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mrngs\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mrngs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mmutable\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mmutable\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mtemporary\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m \u001b[0;32mas\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 687\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mroot\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 688\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0mmutable\u001b[0m \u001b[0;32mis\u001b[0m \u001b[0;32mnot\u001b[0m \u001b[0;32mFalse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 689\u001b[0m \u001b[0;32mreturn\u001b[0m \u001b[0my\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mroot\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmutable_variables\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
453
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mscope_fn\u001b[0;34m(scope, *args, **kwargs)\u001b[0m\n\u001b[1;32m 1214\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mcapture_intermediates\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1215\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1216\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0mfn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mmodule\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mclone\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mparent\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mscope\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1217\u001b[0m \u001b[0;32mfinally\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1218\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mpop\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
454
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
455
+ "\u001b[0;32m<ipython-input-7-b2eaa3f7b251>\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, input_ids, attention_mask, position_ids, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 47\u001b[0m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 48\u001b[0m \u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_hidden_states\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m---> 49\u001b[0;31m \u001b[0mreturn_dict\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mreturn_dict\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 50\u001b[0m )\n\u001b[1;32m 51\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n",
456
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
457
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, hidden_states, attention_mask, deterministic, init_cache, output_attentions, output_hidden_states, return_dict)\u001b[0m\n\u001b[1;32m 452\u001b[0m \u001b[0mdeterministic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdeterministic\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 453\u001b[0m \u001b[0minit_cache\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_cache\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 454\u001b[0;31m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 455\u001b[0m )\n\u001b[1;32m 456\u001b[0m \u001b[0mhidden_states\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mlayer_outputs\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
458
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
459
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, hidden_states, attention_mask, deterministic, init_cache, output_attentions)\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mdeterministic\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0mdeterministic\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 286\u001b[0m \u001b[0minit_cache\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0minit_cache\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 287\u001b[0;31m \u001b[0moutput_attentions\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0moutput_attentions\u001b[0m\u001b[0;34m,\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 288\u001b[0m )\n\u001b[1;32m 289\u001b[0m \u001b[0;31m# residual connection\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
460
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/flax/linen/module.py\u001b[0m in \u001b[0;36mwrapped_module_method\u001b[0;34m(*args, **kwargs)\u001b[0m\n\u001b[1;32m 281\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mmodule_stack\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mappend\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 282\u001b[0m \u001b[0;32mtry\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 283\u001b[0;31m \u001b[0my\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mfun\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mself\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m*\u001b[0m\u001b[0margs\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m**\u001b[0m\u001b[0mkwargs\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 284\u001b[0m \u001b[0;32mif\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 285\u001b[0m \u001b[0mfilter_fn\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0m_context\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mcapture_stack\u001b[0m\u001b[0;34m[\u001b[0m\u001b[0;34m-\u001b[0m\u001b[0;36m1\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
461
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/transformers/models/gpt2/modeling_flax_gpt2.py\u001b[0m in \u001b[0;36m__call__\u001b[0;34m(self, hidden_states, attention_mask, deterministic, init_cache, output_attentions)\u001b[0m\n\u001b[1;32m 177\u001b[0m ):\n\u001b[1;32m 178\u001b[0m \u001b[0mqkv_out\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0mc_attn\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mhidden_states\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m--> 179\u001b[0;31m \u001b[0mquery\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mkey\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mvalue\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mjnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mqkv_out\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;36m3\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0;36m2\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 180\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 181\u001b[0m \u001b[0mquery\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0mself\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0m_split_heads\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mquery\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
462
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36msplit\u001b[0;34m(ary, indices_or_sections, axis)\u001b[0m\n\u001b[1;32m 1806\u001b[0m \u001b[0;34m@\u001b[0m\u001b[0m_wraps\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp\u001b[0m\u001b[0;34m.\u001b[0m\u001b[0msplit\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1807\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0msplit\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices_or_sections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0mint\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;36m0\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1808\u001b[0;31m \u001b[0;32mreturn\u001b[0m \u001b[0m_split\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"split\"\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mary\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mindices_or_sections\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m=\u001b[0m\u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1809\u001b[0m \u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1810\u001b[0m \u001b[0;32mdef\u001b[0m \u001b[0m_split_on_axis\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mnp_fun\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0maxis\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
463
+ "\u001b[0;32m/usr/local/lib/python3.7/dist-packages/jax/_src/numpy/lax_numpy.py\u001b[0m in \u001b[0;36m_split\u001b[0;34m(op, ary, indices_or_sections, axis)\u001b[0m\n\u001b[1;32m 1798\u001b[0m + ((r + 1) * (part_size + 1) - 1)])\n\u001b[1;32m 1799\u001b[0m \u001b[0;32melse\u001b[0m\u001b[0;34m:\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0;32m-> 1800\u001b[0;31m \u001b[0;32mraise\u001b[0m \u001b[0mValueError\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0;34m\"array split does not result in an equal division\"\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[0m\u001b[1;32m 1801\u001b[0m \u001b[0mstarts\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mends\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;36m0\u001b[0m\u001b[0;34m]\u001b[0m \u001b[0;34m*\u001b[0m \u001b[0mndim\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mshape\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mary\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n\u001b[1;32m 1802\u001b[0m \u001b[0m_subval\u001b[0m \u001b[0;34m=\u001b[0m \u001b[0;32mlambda\u001b[0m \u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m:\u001b[0m \u001b[0msubvals\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mx\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0;34m[\u001b[0m\u001b[0;34m(\u001b[0m\u001b[0mi\u001b[0m\u001b[0;34m,\u001b[0m \u001b[0mv\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m]\u001b[0m\u001b[0;34m)\u001b[0m\u001b[0;34m\u001b[0m\u001b[0;34m\u001b[0m\u001b[0m\n",
464
+ "\u001b[0;31mValueError\u001b[0m: array split does not result in an equal division"
465
+ ]
466
+ }
467
+ ]
468
+ }
469
+ ]
470
+ }
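
Note on the ValueError recorded above: the traceback ends inside the GPT-2 attention at jnp.split(qkv_out, 3, axis=2). Because input_ids has shape (1, 2, 11), the hidden states stay 4-dimensional, so axis 2 is the sequence axis (length 11, not divisible by 3) rather than the feature axis, which is what raises "array split does not result in an equal division". The usual multiple-choice convention is to flatten (batch, num_choices, seq_len) inputs to (batch * num_choices, seq_len) before the backbone and reshape the logits back afterwards; it also appears more conventional to compose the inner FlaxGGGModule inside FlaxGPT2ForMultipleChoiceModule rather than the FlaxNewModel pretrained wrapper. The sketch below is illustrative only and not part of the committed notebook; the helper name flatten_choices is invented for the example.

import jax.numpy as jnp

def flatten_choices(input_ids, attention_mask):
    # (batch, num_choices, seq_len) -> (batch * num_choices, seq_len)
    seq_len = input_ids.shape[-1]
    return input_ids.reshape(-1, seq_len), attention_mask.reshape(-1, seq_len)

# Shapes from the failing cell:
input_ids = jnp.ones((1, 2, 11), dtype=jnp.int32)
attention_mask = jnp.ones((1, 2, 11), dtype=jnp.int32)

flat_ids, flat_mask = flatten_choices(input_ids, attention_mask)
print(flat_ids.shape)  # (2, 11): 2-D, so the attention's qkv split happens on the feature axis
# Per-choice logits from the backbone would then be reshaped back with logits.reshape(1, 2).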