karanzrk commited on
Commit
aef2b93
1 Parent(s): 87f0430

Added training script

Browse files
Files changed (2) hide show
  1. ML Canvas Group 7.pdf +0 -0
  2. training.ipynb +1625 -0
ML Canvas Group 7.pdf ADDED
Binary file (110 kB). View file
 
training.ipynb ADDED
@@ -0,0 +1,1625 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "nbformat": 4,
3
+ "nbformat_minor": 0,
4
+ "metadata": {
5
+ "colab": {
6
+ "provenance": [],
7
+ "gpuType": "T4"
8
+ },
9
+ "kernelspec": {
10
+ "name": "python3",
11
+ "display_name": "Python 3"
12
+ },
13
+ "language_info": {
14
+ "name": "python"
15
+ },
16
+ "accelerator": "GPU",
17
+ "widgets": {
18
+ "application/vnd.jupyter.widget-state+json": {
19
+ "e68b6e6997844bf788a057f9c7feedfb": {
20
+ "model_module": "@jupyter-widgets/controls",
21
+ "model_name": "HBoxModel",
22
+ "model_module_version": "1.5.0",
23
+ "state": {
24
+ "_dom_classes": [],
25
+ "_model_module": "@jupyter-widgets/controls",
26
+ "_model_module_version": "1.5.0",
27
+ "_model_name": "HBoxModel",
28
+ "_view_count": null,
29
+ "_view_module": "@jupyter-widgets/controls",
30
+ "_view_module_version": "1.5.0",
31
+ "_view_name": "HBoxView",
32
+ "box_style": "",
33
+ "children": [
34
+ "IPY_MODEL_295e4080ccd64e48806a36b83e50ddfa",
35
+ "IPY_MODEL_c4025862f06b412cb99165b67ad7daae",
36
+ "IPY_MODEL_5ac369dab692489cb13cdb664c47fd96"
37
+ ],
38
+ "layout": "IPY_MODEL_434aa0b7bd76440d9b9b64d8b53133d3"
39
+ }
40
+ },
41
+ "295e4080ccd64e48806a36b83e50ddfa": {
42
+ "model_module": "@jupyter-widgets/controls",
43
+ "model_name": "HTMLModel",
44
+ "model_module_version": "1.5.0",
45
+ "state": {
46
+ "_dom_classes": [],
47
+ "_model_module": "@jupyter-widgets/controls",
48
+ "_model_module_version": "1.5.0",
49
+ "_model_name": "HTMLModel",
50
+ "_view_count": null,
51
+ "_view_module": "@jupyter-widgets/controls",
52
+ "_view_module_version": "1.5.0",
53
+ "_view_name": "HTMLView",
54
+ "description": "",
55
+ "description_tooltip": null,
56
+ "layout": "IPY_MODEL_9e2a1fea814f408ebb4d15db83b1130b",
57
+ "placeholder": "​",
58
+ "style": "IPY_MODEL_4a2f178864244d68bd915ee57379251d",
59
+ "value": "Map: 100%"
60
+ }
61
+ },
62
+ "c4025862f06b412cb99165b67ad7daae": {
63
+ "model_module": "@jupyter-widgets/controls",
64
+ "model_name": "FloatProgressModel",
65
+ "model_module_version": "1.5.0",
66
+ "state": {
67
+ "_dom_classes": [],
68
+ "_model_module": "@jupyter-widgets/controls",
69
+ "_model_module_version": "1.5.0",
70
+ "_model_name": "FloatProgressModel",
71
+ "_view_count": null,
72
+ "_view_module": "@jupyter-widgets/controls",
73
+ "_view_module_version": "1.5.0",
74
+ "_view_name": "ProgressView",
75
+ "bar_style": "success",
76
+ "description": "",
77
+ "description_tooltip": null,
78
+ "layout": "IPY_MODEL_7125f94d482a46999fd4dd3be1b3e87e",
79
+ "max": 1148,
80
+ "min": 0,
81
+ "orientation": "horizontal",
82
+ "style": "IPY_MODEL_96486cdef9714482a4ffa2aca1b3628b",
83
+ "value": 1148
84
+ }
85
+ },
86
+ "5ac369dab692489cb13cdb664c47fd96": {
87
+ "model_module": "@jupyter-widgets/controls",
88
+ "model_name": "HTMLModel",
89
+ "model_module_version": "1.5.0",
90
+ "state": {
91
+ "_dom_classes": [],
92
+ "_model_module": "@jupyter-widgets/controls",
93
+ "_model_module_version": "1.5.0",
94
+ "_model_name": "HTMLModel",
95
+ "_view_count": null,
96
+ "_view_module": "@jupyter-widgets/controls",
97
+ "_view_module_version": "1.5.0",
98
+ "_view_name": "HTMLView",
99
+ "description": "",
100
+ "description_tooltip": null,
101
+ "layout": "IPY_MODEL_2364eb3ce5b345788902c5f9d316a00a",
102
+ "placeholder": "​",
103
+ "style": "IPY_MODEL_52f799ea10d4403cb18e33ba80d739d3",
104
+ "value": " 1148/1148 [00:01<00:00, 781.70 examples/s]"
105
+ }
106
+ },
107
+ "434aa0b7bd76440d9b9b64d8b53133d3": {
108
+ "model_module": "@jupyter-widgets/base",
109
+ "model_name": "LayoutModel",
110
+ "model_module_version": "1.2.0",
111
+ "state": {
112
+ "_model_module": "@jupyter-widgets/base",
113
+ "_model_module_version": "1.2.0",
114
+ "_model_name": "LayoutModel",
115
+ "_view_count": null,
116
+ "_view_module": "@jupyter-widgets/base",
117
+ "_view_module_version": "1.2.0",
118
+ "_view_name": "LayoutView",
119
+ "align_content": null,
120
+ "align_items": null,
121
+ "align_self": null,
122
+ "border": null,
123
+ "bottom": null,
124
+ "display": null,
125
+ "flex": null,
126
+ "flex_flow": null,
127
+ "grid_area": null,
128
+ "grid_auto_columns": null,
129
+ "grid_auto_flow": null,
130
+ "grid_auto_rows": null,
131
+ "grid_column": null,
132
+ "grid_gap": null,
133
+ "grid_row": null,
134
+ "grid_template_areas": null,
135
+ "grid_template_columns": null,
136
+ "grid_template_rows": null,
137
+ "height": null,
138
+ "justify_content": null,
139
+ "justify_items": null,
140
+ "left": null,
141
+ "margin": null,
142
+ "max_height": null,
143
+ "max_width": null,
144
+ "min_height": null,
145
+ "min_width": null,
146
+ "object_fit": null,
147
+ "object_position": null,
148
+ "order": null,
149
+ "overflow": null,
150
+ "overflow_x": null,
151
+ "overflow_y": null,
152
+ "padding": null,
153
+ "right": null,
154
+ "top": null,
155
+ "visibility": null,
156
+ "width": null
157
+ }
158
+ },
159
+ "9e2a1fea814f408ebb4d15db83b1130b": {
160
+ "model_module": "@jupyter-widgets/base",
161
+ "model_name": "LayoutModel",
162
+ "model_module_version": "1.2.0",
163
+ "state": {
164
+ "_model_module": "@jupyter-widgets/base",
165
+ "_model_module_version": "1.2.0",
166
+ "_model_name": "LayoutModel",
167
+ "_view_count": null,
168
+ "_view_module": "@jupyter-widgets/base",
169
+ "_view_module_version": "1.2.0",
170
+ "_view_name": "LayoutView",
171
+ "align_content": null,
172
+ "align_items": null,
173
+ "align_self": null,
174
+ "border": null,
175
+ "bottom": null,
176
+ "display": null,
177
+ "flex": null,
178
+ "flex_flow": null,
179
+ "grid_area": null,
180
+ "grid_auto_columns": null,
181
+ "grid_auto_flow": null,
182
+ "grid_auto_rows": null,
183
+ "grid_column": null,
184
+ "grid_gap": null,
185
+ "grid_row": null,
186
+ "grid_template_areas": null,
187
+ "grid_template_columns": null,
188
+ "grid_template_rows": null,
189
+ "height": null,
190
+ "justify_content": null,
191
+ "justify_items": null,
192
+ "left": null,
193
+ "margin": null,
194
+ "max_height": null,
195
+ "max_width": null,
196
+ "min_height": null,
197
+ "min_width": null,
198
+ "object_fit": null,
199
+ "object_position": null,
200
+ "order": null,
201
+ "overflow": null,
202
+ "overflow_x": null,
203
+ "overflow_y": null,
204
+ "padding": null,
205
+ "right": null,
206
+ "top": null,
207
+ "visibility": null,
208
+ "width": null
209
+ }
210
+ },
211
+ "4a2f178864244d68bd915ee57379251d": {
212
+ "model_module": "@jupyter-widgets/controls",
213
+ "model_name": "DescriptionStyleModel",
214
+ "model_module_version": "1.5.0",
215
+ "state": {
216
+ "_model_module": "@jupyter-widgets/controls",
217
+ "_model_module_version": "1.5.0",
218
+ "_model_name": "DescriptionStyleModel",
219
+ "_view_count": null,
220
+ "_view_module": "@jupyter-widgets/base",
221
+ "_view_module_version": "1.2.0",
222
+ "_view_name": "StyleView",
223
+ "description_width": ""
224
+ }
225
+ },
226
+ "7125f94d482a46999fd4dd3be1b3e87e": {
227
+ "model_module": "@jupyter-widgets/base",
228
+ "model_name": "LayoutModel",
229
+ "model_module_version": "1.2.0",
230
+ "state": {
231
+ "_model_module": "@jupyter-widgets/base",
232
+ "_model_module_version": "1.2.0",
233
+ "_model_name": "LayoutModel",
234
+ "_view_count": null,
235
+ "_view_module": "@jupyter-widgets/base",
236
+ "_view_module_version": "1.2.0",
237
+ "_view_name": "LayoutView",
238
+ "align_content": null,
239
+ "align_items": null,
240
+ "align_self": null,
241
+ "border": null,
242
+ "bottom": null,
243
+ "display": null,
244
+ "flex": null,
245
+ "flex_flow": null,
246
+ "grid_area": null,
247
+ "grid_auto_columns": null,
248
+ "grid_auto_flow": null,
249
+ "grid_auto_rows": null,
250
+ "grid_column": null,
251
+ "grid_gap": null,
252
+ "grid_row": null,
253
+ "grid_template_areas": null,
254
+ "grid_template_columns": null,
255
+ "grid_template_rows": null,
256
+ "height": null,
257
+ "justify_content": null,
258
+ "justify_items": null,
259
+ "left": null,
260
+ "margin": null,
261
+ "max_height": null,
262
+ "max_width": null,
263
+ "min_height": null,
264
+ "min_width": null,
265
+ "object_fit": null,
266
+ "object_position": null,
267
+ "order": null,
268
+ "overflow": null,
269
+ "overflow_x": null,
270
+ "overflow_y": null,
271
+ "padding": null,
272
+ "right": null,
273
+ "top": null,
274
+ "visibility": null,
275
+ "width": null
276
+ }
277
+ },
278
+ "96486cdef9714482a4ffa2aca1b3628b": {
279
+ "model_module": "@jupyter-widgets/controls",
280
+ "model_name": "ProgressStyleModel",
281
+ "model_module_version": "1.5.0",
282
+ "state": {
283
+ "_model_module": "@jupyter-widgets/controls",
284
+ "_model_module_version": "1.5.0",
285
+ "_model_name": "ProgressStyleModel",
286
+ "_view_count": null,
287
+ "_view_module": "@jupyter-widgets/base",
288
+ "_view_module_version": "1.2.0",
289
+ "_view_name": "StyleView",
290
+ "bar_color": null,
291
+ "description_width": ""
292
+ }
293
+ },
294
+ "2364eb3ce5b345788902c5f9d316a00a": {
295
+ "model_module": "@jupyter-widgets/base",
296
+ "model_name": "LayoutModel",
297
+ "model_module_version": "1.2.0",
298
+ "state": {
299
+ "_model_module": "@jupyter-widgets/base",
300
+ "_model_module_version": "1.2.0",
301
+ "_model_name": "LayoutModel",
302
+ "_view_count": null,
303
+ "_view_module": "@jupyter-widgets/base",
304
+ "_view_module_version": "1.2.0",
305
+ "_view_name": "LayoutView",
306
+ "align_content": null,
307
+ "align_items": null,
308
+ "align_self": null,
309
+ "border": null,
310
+ "bottom": null,
311
+ "display": null,
312
+ "flex": null,
313
+ "flex_flow": null,
314
+ "grid_area": null,
315
+ "grid_auto_columns": null,
316
+ "grid_auto_flow": null,
317
+ "grid_auto_rows": null,
318
+ "grid_column": null,
319
+ "grid_gap": null,
320
+ "grid_row": null,
321
+ "grid_template_areas": null,
322
+ "grid_template_columns": null,
323
+ "grid_template_rows": null,
324
+ "height": null,
325
+ "justify_content": null,
326
+ "justify_items": null,
327
+ "left": null,
328
+ "margin": null,
329
+ "max_height": null,
330
+ "max_width": null,
331
+ "min_height": null,
332
+ "min_width": null,
333
+ "object_fit": null,
334
+ "object_position": null,
335
+ "order": null,
336
+ "overflow": null,
337
+ "overflow_x": null,
338
+ "overflow_y": null,
339
+ "padding": null,
340
+ "right": null,
341
+ "top": null,
342
+ "visibility": null,
343
+ "width": null
344
+ }
345
+ },
346
+ "52f799ea10d4403cb18e33ba80d739d3": {
347
+ "model_module": "@jupyter-widgets/controls",
348
+ "model_name": "DescriptionStyleModel",
349
+ "model_module_version": "1.5.0",
350
+ "state": {
351
+ "_model_module": "@jupyter-widgets/controls",
352
+ "_model_module_version": "1.5.0",
353
+ "_model_name": "DescriptionStyleModel",
354
+ "_view_count": null,
355
+ "_view_module": "@jupyter-widgets/base",
356
+ "_view_module_version": "1.2.0",
357
+ "_view_name": "StyleView",
358
+ "description_width": ""
359
+ }
360
+ },
361
+ "3e18acb6f1504f4dace716a96e8d90f4": {
362
+ "model_module": "@jupyter-widgets/controls",
363
+ "model_name": "HBoxModel",
364
+ "model_module_version": "1.5.0",
365
+ "state": {
366
+ "_dom_classes": [],
367
+ "_model_module": "@jupyter-widgets/controls",
368
+ "_model_module_version": "1.5.0",
369
+ "_model_name": "HBoxModel",
370
+ "_view_count": null,
371
+ "_view_module": "@jupyter-widgets/controls",
372
+ "_view_module_version": "1.5.0",
373
+ "_view_name": "HBoxView",
374
+ "box_style": "",
375
+ "children": [
376
+ "IPY_MODEL_953e7d76140e4ed2ade688ccd5467a75",
377
+ "IPY_MODEL_3a70d75b4eb949598e7cb9430acfcf81",
378
+ "IPY_MODEL_54719990ff1f40cb8fed06badb378d01"
379
+ ],
380
+ "layout": "IPY_MODEL_5d1be2eaa2c143bbbc35f7d0f33f64de"
381
+ }
382
+ },
383
+ "953e7d76140e4ed2ade688ccd5467a75": {
384
+ "model_module": "@jupyter-widgets/controls",
385
+ "model_name": "HTMLModel",
386
+ "model_module_version": "1.5.0",
387
+ "state": {
388
+ "_dom_classes": [],
389
+ "_model_module": "@jupyter-widgets/controls",
390
+ "_model_module_version": "1.5.0",
391
+ "_model_name": "HTMLModel",
392
+ "_view_count": null,
393
+ "_view_module": "@jupyter-widgets/controls",
394
+ "_view_module_version": "1.5.0",
395
+ "_view_name": "HTMLView",
396
+ "description": "",
397
+ "description_tooltip": null,
398
+ "layout": "IPY_MODEL_002c9d35efa54fccb875a08e7059997f",
399
+ "placeholder": "​",
400
+ "style": "IPY_MODEL_21dd8d7b7e5a4e27922ff1e3bec7745a",
401
+ "value": "Map: 100%"
402
+ }
403
+ },
404
+ "3a70d75b4eb949598e7cb9430acfcf81": {
405
+ "model_module": "@jupyter-widgets/controls",
406
+ "model_name": "FloatProgressModel",
407
+ "model_module_version": "1.5.0",
408
+ "state": {
409
+ "_dom_classes": [],
410
+ "_model_module": "@jupyter-widgets/controls",
411
+ "_model_module_version": "1.5.0",
412
+ "_model_name": "FloatProgressModel",
413
+ "_view_count": null,
414
+ "_view_module": "@jupyter-widgets/controls",
415
+ "_view_module_version": "1.5.0",
416
+ "_view_name": "ProgressView",
417
+ "bar_style": "success",
418
+ "description": "",
419
+ "description_tooltip": null,
420
+ "layout": "IPY_MODEL_48abc963896a404886fbcf75b0b19bb9",
421
+ "max": 287,
422
+ "min": 0,
423
+ "orientation": "horizontal",
424
+ "style": "IPY_MODEL_87e3a17419334bf8b2448a8914f9d721",
425
+ "value": 287
426
+ }
427
+ },
428
+ "54719990ff1f40cb8fed06badb378d01": {
429
+ "model_module": "@jupyter-widgets/controls",
430
+ "model_name": "HTMLModel",
431
+ "model_module_version": "1.5.0",
432
+ "state": {
433
+ "_dom_classes": [],
434
+ "_model_module": "@jupyter-widgets/controls",
435
+ "_model_module_version": "1.5.0",
436
+ "_model_name": "HTMLModel",
437
+ "_view_count": null,
438
+ "_view_module": "@jupyter-widgets/controls",
439
+ "_view_module_version": "1.5.0",
440
+ "_view_name": "HTMLView",
441
+ "description": "",
442
+ "description_tooltip": null,
443
+ "layout": "IPY_MODEL_f8303a91b4084791971947ca45c6b459",
444
+ "placeholder": "​",
445
+ "style": "IPY_MODEL_a878599cc49347a896c793f3c45914e3",
446
+ "value": " 287/287 [00:00<00:00, 556.23 examples/s]"
447
+ }
448
+ },
449
+ "5d1be2eaa2c143bbbc35f7d0f33f64de": {
450
+ "model_module": "@jupyter-widgets/base",
451
+ "model_name": "LayoutModel",
452
+ "model_module_version": "1.2.0",
453
+ "state": {
454
+ "_model_module": "@jupyter-widgets/base",
455
+ "_model_module_version": "1.2.0",
456
+ "_model_name": "LayoutModel",
457
+ "_view_count": null,
458
+ "_view_module": "@jupyter-widgets/base",
459
+ "_view_module_version": "1.2.0",
460
+ "_view_name": "LayoutView",
461
+ "align_content": null,
462
+ "align_items": null,
463
+ "align_self": null,
464
+ "border": null,
465
+ "bottom": null,
466
+ "display": null,
467
+ "flex": null,
468
+ "flex_flow": null,
469
+ "grid_area": null,
470
+ "grid_auto_columns": null,
471
+ "grid_auto_flow": null,
472
+ "grid_auto_rows": null,
473
+ "grid_column": null,
474
+ "grid_gap": null,
475
+ "grid_row": null,
476
+ "grid_template_areas": null,
477
+ "grid_template_columns": null,
478
+ "grid_template_rows": null,
479
+ "height": null,
480
+ "justify_content": null,
481
+ "justify_items": null,
482
+ "left": null,
483
+ "margin": null,
484
+ "max_height": null,
485
+ "max_width": null,
486
+ "min_height": null,
487
+ "min_width": null,
488
+ "object_fit": null,
489
+ "object_position": null,
490
+ "order": null,
491
+ "overflow": null,
492
+ "overflow_x": null,
493
+ "overflow_y": null,
494
+ "padding": null,
495
+ "right": null,
496
+ "top": null,
497
+ "visibility": null,
498
+ "width": null
499
+ }
500
+ },
501
+ "002c9d35efa54fccb875a08e7059997f": {
502
+ "model_module": "@jupyter-widgets/base",
503
+ "model_name": "LayoutModel",
504
+ "model_module_version": "1.2.0",
505
+ "state": {
506
+ "_model_module": "@jupyter-widgets/base",
507
+ "_model_module_version": "1.2.0",
508
+ "_model_name": "LayoutModel",
509
+ "_view_count": null,
510
+ "_view_module": "@jupyter-widgets/base",
511
+ "_view_module_version": "1.2.0",
512
+ "_view_name": "LayoutView",
513
+ "align_content": null,
514
+ "align_items": null,
515
+ "align_self": null,
516
+ "border": null,
517
+ "bottom": null,
518
+ "display": null,
519
+ "flex": null,
520
+ "flex_flow": null,
521
+ "grid_area": null,
522
+ "grid_auto_columns": null,
523
+ "grid_auto_flow": null,
524
+ "grid_auto_rows": null,
525
+ "grid_column": null,
526
+ "grid_gap": null,
527
+ "grid_row": null,
528
+ "grid_template_areas": null,
529
+ "grid_template_columns": null,
530
+ "grid_template_rows": null,
531
+ "height": null,
532
+ "justify_content": null,
533
+ "justify_items": null,
534
+ "left": null,
535
+ "margin": null,
536
+ "max_height": null,
537
+ "max_width": null,
538
+ "min_height": null,
539
+ "min_width": null,
540
+ "object_fit": null,
541
+ "object_position": null,
542
+ "order": null,
543
+ "overflow": null,
544
+ "overflow_x": null,
545
+ "overflow_y": null,
546
+ "padding": null,
547
+ "right": null,
548
+ "top": null,
549
+ "visibility": null,
550
+ "width": null
551
+ }
552
+ },
553
+ "21dd8d7b7e5a4e27922ff1e3bec7745a": {
554
+ "model_module": "@jupyter-widgets/controls",
555
+ "model_name": "DescriptionStyleModel",
556
+ "model_module_version": "1.5.0",
557
+ "state": {
558
+ "_model_module": "@jupyter-widgets/controls",
559
+ "_model_module_version": "1.5.0",
560
+ "_model_name": "DescriptionStyleModel",
561
+ "_view_count": null,
562
+ "_view_module": "@jupyter-widgets/base",
563
+ "_view_module_version": "1.2.0",
564
+ "_view_name": "StyleView",
565
+ "description_width": ""
566
+ }
567
+ },
568
+ "48abc963896a404886fbcf75b0b19bb9": {
569
+ "model_module": "@jupyter-widgets/base",
570
+ "model_name": "LayoutModel",
571
+ "model_module_version": "1.2.0",
572
+ "state": {
573
+ "_model_module": "@jupyter-widgets/base",
574
+ "_model_module_version": "1.2.0",
575
+ "_model_name": "LayoutModel",
576
+ "_view_count": null,
577
+ "_view_module": "@jupyter-widgets/base",
578
+ "_view_module_version": "1.2.0",
579
+ "_view_name": "LayoutView",
580
+ "align_content": null,
581
+ "align_items": null,
582
+ "align_self": null,
583
+ "border": null,
584
+ "bottom": null,
585
+ "display": null,
586
+ "flex": null,
587
+ "flex_flow": null,
588
+ "grid_area": null,
589
+ "grid_auto_columns": null,
590
+ "grid_auto_flow": null,
591
+ "grid_auto_rows": null,
592
+ "grid_column": null,
593
+ "grid_gap": null,
594
+ "grid_row": null,
595
+ "grid_template_areas": null,
596
+ "grid_template_columns": null,
597
+ "grid_template_rows": null,
598
+ "height": null,
599
+ "justify_content": null,
600
+ "justify_items": null,
601
+ "left": null,
602
+ "margin": null,
603
+ "max_height": null,
604
+ "max_width": null,
605
+ "min_height": null,
606
+ "min_width": null,
607
+ "object_fit": null,
608
+ "object_position": null,
609
+ "order": null,
610
+ "overflow": null,
611
+ "overflow_x": null,
612
+ "overflow_y": null,
613
+ "padding": null,
614
+ "right": null,
615
+ "top": null,
616
+ "visibility": null,
617
+ "width": null
618
+ }
619
+ },
620
+ "87e3a17419334bf8b2448a8914f9d721": {
621
+ "model_module": "@jupyter-widgets/controls",
622
+ "model_name": "ProgressStyleModel",
623
+ "model_module_version": "1.5.0",
624
+ "state": {
625
+ "_model_module": "@jupyter-widgets/controls",
626
+ "_model_module_version": "1.5.0",
627
+ "_model_name": "ProgressStyleModel",
628
+ "_view_count": null,
629
+ "_view_module": "@jupyter-widgets/base",
630
+ "_view_module_version": "1.2.0",
631
+ "_view_name": "StyleView",
632
+ "bar_color": null,
633
+ "description_width": ""
634
+ }
635
+ },
636
+ "f8303a91b4084791971947ca45c6b459": {
637
+ "model_module": "@jupyter-widgets/base",
638
+ "model_name": "LayoutModel",
639
+ "model_module_version": "1.2.0",
640
+ "state": {
641
+ "_model_module": "@jupyter-widgets/base",
642
+ "_model_module_version": "1.2.0",
643
+ "_model_name": "LayoutModel",
644
+ "_view_count": null,
645
+ "_view_module": "@jupyter-widgets/base",
646
+ "_view_module_version": "1.2.0",
647
+ "_view_name": "LayoutView",
648
+ "align_content": null,
649
+ "align_items": null,
650
+ "align_self": null,
651
+ "border": null,
652
+ "bottom": null,
653
+ "display": null,
654
+ "flex": null,
655
+ "flex_flow": null,
656
+ "grid_area": null,
657
+ "grid_auto_columns": null,
658
+ "grid_auto_flow": null,
659
+ "grid_auto_rows": null,
660
+ "grid_column": null,
661
+ "grid_gap": null,
662
+ "grid_row": null,
663
+ "grid_template_areas": null,
664
+ "grid_template_columns": null,
665
+ "grid_template_rows": null,
666
+ "height": null,
667
+ "justify_content": null,
668
+ "justify_items": null,
669
+ "left": null,
670
+ "margin": null,
671
+ "max_height": null,
672
+ "max_width": null,
673
+ "min_height": null,
674
+ "min_width": null,
675
+ "object_fit": null,
676
+ "object_position": null,
677
+ "order": null,
678
+ "overflow": null,
679
+ "overflow_x": null,
680
+ "overflow_y": null,
681
+ "padding": null,
682
+ "right": null,
683
+ "top": null,
684
+ "visibility": null,
685
+ "width": null
686
+ }
687
+ },
688
+ "a878599cc49347a896c793f3c45914e3": {
689
+ "model_module": "@jupyter-widgets/controls",
690
+ "model_name": "DescriptionStyleModel",
691
+ "model_module_version": "1.5.0",
692
+ "state": {
693
+ "_model_module": "@jupyter-widgets/controls",
694
+ "_model_module_version": "1.5.0",
695
+ "_model_name": "DescriptionStyleModel",
696
+ "_view_count": null,
697
+ "_view_module": "@jupyter-widgets/base",
698
+ "_view_module_version": "1.2.0",
699
+ "_view_name": "StyleView",
700
+ "description_width": ""
701
+ }
702
+ }
703
+ }
704
+ }
705
+ },
706
+ "cells": [
707
+ {
708
+ "cell_type": "code",
709
+ "execution_count": 1,
710
+ "metadata": {
711
+ "colab": {
712
+ "base_uri": "https://localhost:8080/"
713
+ },
714
+ "id": "SRajt-tUH3ms",
715
+ "outputId": "f6077695-1508-4b60-b33a-7a29f37b4c75"
716
+ },
717
+ "outputs": [
718
+ {
719
+ "output_type": "stream",
720
+ "name": "stdout",
721
+ "text": [
722
+ "Requirement already satisfied: transformers in /usr/local/lib/python3.10/dist-packages (4.31.0)\n",
723
+ "Requirement already satisfied: datasets in /usr/local/lib/python3.10/dist-packages (2.14.4)\n",
724
+ "Requirement already satisfied: evaluate in /usr/local/lib/python3.10/dist-packages (0.4.0)\n",
725
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers) (3.12.2)\n",
726
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.16.4)\n",
727
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (1.23.5)\n",
728
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers) (23.1)\n",
729
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (6.0.1)\n",
730
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers) (2023.6.3)\n",
731
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers) (2.31.0)\n",
732
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.13.3)\n",
733
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers) (0.3.2)\n",
734
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers) (4.66.1)\n",
735
+ "Requirement already satisfied: pyarrow>=8.0.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (9.0.0)\n",
736
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /usr/local/lib/python3.10/dist-packages (from datasets) (0.3.7)\n",
737
+ "Requirement already satisfied: pandas in /usr/local/lib/python3.10/dist-packages (from datasets) (1.5.3)\n",
738
+ "Requirement already satisfied: xxhash in /usr/local/lib/python3.10/dist-packages (from datasets) (3.3.0)\n",
739
+ "Requirement already satisfied: multiprocess in /usr/local/lib/python3.10/dist-packages (from datasets) (0.70.15)\n",
740
+ "Requirement already satisfied: fsspec[http]>=2021.11.1 in /usr/local/lib/python3.10/dist-packages (from datasets) (2023.6.0)\n",
741
+ "Requirement already satisfied: aiohttp in /usr/local/lib/python3.10/dist-packages (from datasets) (3.8.5)\n",
742
+ "Requirement already satisfied: responses<0.19 in /usr/local/lib/python3.10/dist-packages (from evaluate) (0.18.0)\n",
743
+ "Requirement already satisfied: attrs>=17.3.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (23.1.0)\n",
744
+ "Requirement already satisfied: charset-normalizer<4.0,>=2.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (3.2.0)\n",
745
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (6.0.4)\n",
746
+ "Requirement already satisfied: async-timeout<5.0,>=4.0.0a3 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (4.0.3)\n",
747
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.9.2)\n",
748
+ "Requirement already satisfied: frozenlist>=1.1.1 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.4.0)\n",
749
+ "Requirement already satisfied: aiosignal>=1.1.2 in /usr/local/lib/python3.10/dist-packages (from aiohttp->datasets) (1.3.1)\n",
750
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers) (4.7.1)\n",
751
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (3.4)\n",
752
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2.0.4)\n",
753
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers) (2023.7.22)\n",
754
+ "Requirement already satisfied: python-dateutil>=2.8.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2.8.2)\n",
755
+ "Requirement already satisfied: pytz>=2020.1 in /usr/local/lib/python3.10/dist-packages (from pandas->datasets) (2023.3)\n",
756
+ "Requirement already satisfied: six>=1.5 in /usr/local/lib/python3.10/dist-packages (from python-dateutil>=2.8.1->pandas->datasets) (1.16.0)\n"
757
+ ]
758
+ }
759
+ ],
760
+ "source": [
761
+ "! pip install transformers datasets evaluate"
762
+ ]
763
+ },
764
+ {
765
+ "cell_type": "code",
766
+ "source": [
767
+ "from transformers import AutoTokenizer\n",
768
+ "\n",
769
+ "tokenizer = AutoTokenizer.from_pretrained(\"bert-base-uncased\")"
770
+ ],
771
+ "metadata": {
772
+ "id": "rjE6lHHJJdyv"
773
+ },
774
+ "execution_count": 2,
775
+ "outputs": []
776
+ },
777
+ {
778
+ "cell_type": "code",
779
+ "source": [
780
+ "import pandas as pd\n",
781
+ "from sklearn.model_selection import train_test_split\n",
782
+ "\n",
783
+ "data = pd.read_csv(\"ielts_writing_dataset_new.csv\")\n",
784
+ "\n",
785
+ "data.label = data.label.replace(1,0)\n",
786
+ "data.label = data.label.replace(3,0)\n",
787
+ "data.label = data.label.replace(3.5,0)\n",
788
+ "data.label = data.label.replace(4,0)\n",
789
+ "data.label = data.label.replace(4.5,0)\n",
790
+ "data.label = data.label.replace(5,0)\n",
791
+ "data.label = data.label.replace(5.5,1)\n",
792
+ "data.label = data.label.replace(6,1)\n",
793
+ "data.label = data.label.replace(6.5,1)\n",
794
+ "data.label = data.label.replace(7,1)\n",
795
+ "data.label = data.label.replace(7.5,1)\n",
796
+ "data.label = data.label.replace(8,2)\n",
797
+ "data.label = data.label.replace(8.5,2)\n",
798
+ "data.label = data.label.replace(9,2)\n",
799
+ "\n",
800
+ "data.label = data.label.astype(int)\n",
801
+ "\n",
802
+ "train, test = train_test_split(data, test_size=0.2)\n"
803
+ ],
804
+ "metadata": {
805
+ "id": "GpD5w5t2JihL"
806
+ },
807
+ "execution_count": 3,
808
+ "outputs": []
809
+ },
810
+ {
811
+ "cell_type": "code",
812
+ "source": [
813
+ "data[:10]"
814
+ ],
815
+ "metadata": {
816
+ "colab": {
817
+ "base_uri": "https://localhost:8080/",
818
+ "height": 363
819
+ },
820
+ "id": "Cos-ypQ7n7d9",
821
+ "outputId": "92caed9a-43e5-4a28-adf3-1727e3a15357"
822
+ },
823
+ "execution_count": 4,
824
+ "outputs": [
825
+ {
826
+ "output_type": "execute_result",
827
+ "data": {
828
+ "text/plain": [
829
+ " label text\n",
830
+ "0 1 Between 1995 and 2010, a study was conducted r...\n",
831
+ "1 1 Poverty represents a worldwide crisis. It is t...\n",
832
+ "2 0 The left chart shows the population change hap...\n",
833
+ "3 1 Human beings are facing many challenges nowada...\n",
834
+ "4 1 Information about the thousands of visits from...\n",
835
+ "5 1 Whether countries should only invest facilitie...\n",
836
+ "6 1 This graph depicts the changes in tourists vis...\n",
837
+ "7 1 Sports is an essential part to most of us , so...\n",
838
+ "8 2 The line graph illustrates the number of overs...\n",
839
+ "9 2 International sports events require the most w..."
840
+ ],
841
+ "text/html": [
842
+ "\n",
843
+ "\n",
844
+ " <div id=\"df-ee3fdca5-5d9d-44a1-9609-7b8c0c084882\">\n",
845
+ " <div class=\"colab-df-container\">\n",
846
+ " <div>\n",
847
+ "<style scoped>\n",
848
+ " .dataframe tbody tr th:only-of-type {\n",
849
+ " vertical-align: middle;\n",
850
+ " }\n",
851
+ "\n",
852
+ " .dataframe tbody tr th {\n",
853
+ " vertical-align: top;\n",
854
+ " }\n",
855
+ "\n",
856
+ " .dataframe thead th {\n",
857
+ " text-align: right;\n",
858
+ " }\n",
859
+ "</style>\n",
860
+ "<table border=\"1\" class=\"dataframe\">\n",
861
+ " <thead>\n",
862
+ " <tr style=\"text-align: right;\">\n",
863
+ " <th></th>\n",
864
+ " <th>label</th>\n",
865
+ " <th>text</th>\n",
866
+ " </tr>\n",
867
+ " </thead>\n",
868
+ " <tbody>\n",
869
+ " <tr>\n",
870
+ " <th>0</th>\n",
871
+ " <td>1</td>\n",
872
+ " <td>Between 1995 and 2010, a study was conducted r...</td>\n",
873
+ " </tr>\n",
874
+ " <tr>\n",
875
+ " <th>1</th>\n",
876
+ " <td>1</td>\n",
877
+ " <td>Poverty represents a worldwide crisis. It is t...</td>\n",
878
+ " </tr>\n",
879
+ " <tr>\n",
880
+ " <th>2</th>\n",
881
+ " <td>0</td>\n",
882
+ " <td>The left chart shows the population change hap...</td>\n",
883
+ " </tr>\n",
884
+ " <tr>\n",
885
+ " <th>3</th>\n",
886
+ " <td>1</td>\n",
887
+ " <td>Human beings are facing many challenges nowada...</td>\n",
888
+ " </tr>\n",
889
+ " <tr>\n",
890
+ " <th>4</th>\n",
891
+ " <td>1</td>\n",
892
+ " <td>Information about the thousands of visits from...</td>\n",
893
+ " </tr>\n",
894
+ " <tr>\n",
895
+ " <th>5</th>\n",
896
+ " <td>1</td>\n",
897
+ " <td>Whether countries should only invest facilitie...</td>\n",
898
+ " </tr>\n",
899
+ " <tr>\n",
900
+ " <th>6</th>\n",
901
+ " <td>1</td>\n",
902
+ " <td>This graph depicts the changes in tourists vis...</td>\n",
903
+ " </tr>\n",
904
+ " <tr>\n",
905
+ " <th>7</th>\n",
906
+ " <td>1</td>\n",
907
+ " <td>Sports is an essential part to most of us , so...</td>\n",
908
+ " </tr>\n",
909
+ " <tr>\n",
910
+ " <th>8</th>\n",
911
+ " <td>2</td>\n",
912
+ " <td>The line graph illustrates the number of overs...</td>\n",
913
+ " </tr>\n",
914
+ " <tr>\n",
915
+ " <th>9</th>\n",
916
+ " <td>2</td>\n",
917
+ " <td>International sports events require the most w...</td>\n",
918
+ " </tr>\n",
919
+ " </tbody>\n",
920
+ "</table>\n",
921
+ "</div>\n",
922
+ " <button class=\"colab-df-convert\" onclick=\"convertToInteractive('df-ee3fdca5-5d9d-44a1-9609-7b8c0c084882')\"\n",
923
+ " title=\"Convert this dataframe to an interactive table.\"\n",
924
+ " style=\"display:none;\">\n",
925
+ "\n",
926
+ " <svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
927
+ " width=\"24px\">\n",
928
+ " <path d=\"M0 0h24v24H0V0z\" fill=\"none\"/>\n",
929
+ " <path d=\"M18.56 5.44l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94zm-11 1L8.5 8.5l.94-2.06 2.06-.94-2.06-.94L8.5 2.5l-.94 2.06-2.06.94zm10 10l.94 2.06.94-2.06 2.06-.94-2.06-.94-.94-2.06-.94 2.06-2.06.94z\"/><path d=\"M17.41 7.96l-1.37-1.37c-.4-.4-.92-.59-1.43-.59-.52 0-1.04.2-1.43.59L10.3 9.45l-7.72 7.72c-.78.78-.78 2.05 0 2.83L4 21.41c.39.39.9.59 1.41.59.51 0 1.02-.2 1.41-.59l7.78-7.78 2.81-2.81c.8-.78.8-2.07 0-2.86zM5.41 20L4 18.59l7.72-7.72 1.47 1.35L5.41 20z\"/>\n",
930
+ " </svg>\n",
931
+ " </button>\n",
932
+ "\n",
933
+ "\n",
934
+ "\n",
935
+ " <div id=\"df-5159f8f9-e2c4-4afd-a1ba-4432c1f027a9\">\n",
936
+ " <button class=\"colab-df-quickchart\" onclick=\"quickchart('df-5159f8f9-e2c4-4afd-a1ba-4432c1f027a9')\"\n",
937
+ " title=\"Suggest charts.\"\n",
938
+ " style=\"display:none;\">\n",
939
+ "\n",
940
+ "<svg xmlns=\"http://www.w3.org/2000/svg\" height=\"24px\"viewBox=\"0 0 24 24\"\n",
941
+ " width=\"24px\">\n",
942
+ " <g>\n",
943
+ " <path d=\"M19 3H5c-1.1 0-2 .9-2 2v14c0 1.1.9 2 2 2h14c1.1 0 2-.9 2-2V5c0-1.1-.9-2-2-2zM9 17H7v-7h2v7zm4 0h-2V7h2v10zm4 0h-2v-4h2v4z\"/>\n",
944
+ " </g>\n",
945
+ "</svg>\n",
946
+ " </button>\n",
947
+ " </div>\n",
948
+ "\n",
949
+ "<style>\n",
950
+ " .colab-df-quickchart {\n",
951
+ " background-color: #E8F0FE;\n",
952
+ " border: none;\n",
953
+ " border-radius: 50%;\n",
954
+ " cursor: pointer;\n",
955
+ " display: none;\n",
956
+ " fill: #1967D2;\n",
957
+ " height: 32px;\n",
958
+ " padding: 0 0 0 0;\n",
959
+ " width: 32px;\n",
960
+ " }\n",
961
+ "\n",
962
+ " .colab-df-quickchart:hover {\n",
963
+ " background-color: #E2EBFA;\n",
964
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
965
+ " fill: #174EA6;\n",
966
+ " }\n",
967
+ "\n",
968
+ " [theme=dark] .colab-df-quickchart {\n",
969
+ " background-color: #3B4455;\n",
970
+ " fill: #D2E3FC;\n",
971
+ " }\n",
972
+ "\n",
973
+ " [theme=dark] .colab-df-quickchart:hover {\n",
974
+ " background-color: #434B5C;\n",
975
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
976
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
977
+ " fill: #FFFFFF;\n",
978
+ " }\n",
979
+ "</style>\n",
980
+ "\n",
981
+ " <script>\n",
982
+ " async function quickchart(key) {\n",
983
+ " const containerElement = document.querySelector('#' + key);\n",
984
+ " const charts = await google.colab.kernel.invokeFunction(\n",
985
+ " 'suggestCharts', [key], {});\n",
986
+ " }\n",
987
+ " </script>\n",
988
+ "\n",
989
+ "\n",
990
+ " <script>\n",
991
+ "\n",
992
+ "function displayQuickchartButton(domScope) {\n",
993
+ " let quickchartButtonEl =\n",
994
+ " domScope.querySelector('#df-5159f8f9-e2c4-4afd-a1ba-4432c1f027a9 button.colab-df-quickchart');\n",
995
+ " quickchartButtonEl.style.display =\n",
996
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
997
+ "}\n",
998
+ "\n",
999
+ " displayQuickchartButton(document);\n",
1000
+ " </script>\n",
1001
+ " <style>\n",
1002
+ " .colab-df-container {\n",
1003
+ " display:flex;\n",
1004
+ " flex-wrap:wrap;\n",
1005
+ " gap: 12px;\n",
1006
+ " }\n",
1007
+ "\n",
1008
+ " .colab-df-convert {\n",
1009
+ " background-color: #E8F0FE;\n",
1010
+ " border: none;\n",
1011
+ " border-radius: 50%;\n",
1012
+ " cursor: pointer;\n",
1013
+ " display: none;\n",
1014
+ " fill: #1967D2;\n",
1015
+ " height: 32px;\n",
1016
+ " padding: 0 0 0 0;\n",
1017
+ " width: 32px;\n",
1018
+ " }\n",
1019
+ "\n",
1020
+ " .colab-df-convert:hover {\n",
1021
+ " background-color: #E2EBFA;\n",
1022
+ " box-shadow: 0px 1px 2px rgba(60, 64, 67, 0.3), 0px 1px 3px 1px rgba(60, 64, 67, 0.15);\n",
1023
+ " fill: #174EA6;\n",
1024
+ " }\n",
1025
+ "\n",
1026
+ " [theme=dark] .colab-df-convert {\n",
1027
+ " background-color: #3B4455;\n",
1028
+ " fill: #D2E3FC;\n",
1029
+ " }\n",
1030
+ "\n",
1031
+ " [theme=dark] .colab-df-convert:hover {\n",
1032
+ " background-color: #434B5C;\n",
1033
+ " box-shadow: 0px 1px 3px 1px rgba(0, 0, 0, 0.15);\n",
1034
+ " filter: drop-shadow(0px 1px 2px rgba(0, 0, 0, 0.3));\n",
1035
+ " fill: #FFFFFF;\n",
1036
+ " }\n",
1037
+ " </style>\n",
1038
+ "\n",
1039
+ " <script>\n",
1040
+ " const buttonEl =\n",
1041
+ " document.querySelector('#df-ee3fdca5-5d9d-44a1-9609-7b8c0c084882 button.colab-df-convert');\n",
1042
+ " buttonEl.style.display =\n",
1043
+ " google.colab.kernel.accessAllowed ? 'block' : 'none';\n",
1044
+ "\n",
1045
+ " async function convertToInteractive(key) {\n",
1046
+ " const element = document.querySelector('#df-ee3fdca5-5d9d-44a1-9609-7b8c0c084882');\n",
1047
+ " const dataTable =\n",
1048
+ " await google.colab.kernel.invokeFunction('convertToInteractive',\n",
1049
+ " [key], {});\n",
1050
+ " if (!dataTable) return;\n",
1051
+ "\n",
1052
+ " const docLinkHtml = 'Like what you see? Visit the ' +\n",
1053
+ " '<a target=\"_blank\" href=https://colab.research.google.com/notebooks/data_table.ipynb>data table notebook</a>'\n",
1054
+ " + ' to learn more about interactive tables.';\n",
1055
+ " element.innerHTML = '';\n",
1056
+ " dataTable['output_type'] = 'display_data';\n",
1057
+ " await google.colab.output.renderOutput(dataTable, element);\n",
1058
+ " const docLink = document.createElement('div');\n",
1059
+ " docLink.innerHTML = docLinkHtml;\n",
1060
+ " element.appendChild(docLink);\n",
1061
+ " }\n",
1062
+ " </script>\n",
1063
+ " </div>\n",
1064
+ " </div>\n"
1065
+ ]
1066
+ },
1067
+ "metadata": {},
1068
+ "execution_count": 4
1069
+ }
1070
+ ]
1071
+ },
1072
+ {
1073
+ "cell_type": "code",
1074
+ "source": [
1075
+ "import datasets\n",
1076
+ "from datasets import Dataset, DatasetDict\n",
1077
+ "\n",
1078
+ "# Build the splits from the pandas frames without rebinding `train`/`test`,\n",
1079
+ "# so this cell can be re-run; preserve_index=False drops the shuffled pandas\n",
1080
+ "# index instead of storing it as an extra \"__index_level_0__\" column.\n",
1081
+ "dataset = DatasetDict(\n",
1082
+ "    {\n",
1083
+ "        \"train\": Dataset.from_pandas(train, preserve_index=False),\n",
1084
+ "        \"test\": Dataset.from_pandas(test, preserve_index=False),\n",
1085
+ "    }\n",
1086
+ ")\n",
1087
+ "dataset"
1088
+ ],
1089
+ "metadata": {
1090
+ "colab": {
1091
+ "base_uri": "https://localhost:8080/"
1092
+ },
1093
+ "id": "Mi7bkZ00L6ZB",
1094
+ "outputId": "3532f0d9-1961-44fc-ac50-bace0add6005"
1095
+ },
1096
+ "execution_count": 5,
1097
+ "outputs": [
1098
+ {
1099
+ "output_type": "execute_result",
1100
+ "data": {
1101
+ "text/plain": [
1102
+ "DatasetDict({\n",
1103
+ " train: Dataset({\n",
1104
+ " features: ['label', 'text'],\n",
1105
+ " num_rows: 1148\n",
1106
+ " })\n",
1107
+ " test: Dataset({\n",
1108
+ " features: ['label', 'text'],\n",
1109
+ " num_rows: 287\n",
1110
+ " })\n",
1111
+ "})"
1112
+ ]
1113
+ },
1114
+ "metadata": {},
1115
+ "execution_count": 5
1116
+ }
1117
+ ]
1118
+ },
1119
+ {
1120
+ "cell_type": "code",
1121
+ "source": [
1122
+ "dataset[\"test\"][0]"
1123
+ ],
1124
+ "metadata": {
1125
+ "colab": {
1126
+ "base_uri": "https://localhost:8080/"
1127
+ },
1128
+ "id": "QGCPOgv5MO1k",
1129
+ "outputId": "2d26d51c-2c62-4207-b0ac-8570aa89c798"
1130
+ },
1131
+ "execution_count": 6,
1132
+ "outputs": [
1133
+ {
1134
+ "output_type": "execute_result",
1135
+ "data": {
1136
+ "text/plain": [
1137
+ "{'label': 1,\n",
1138
+ " 'text': 'Everything has two sides and the globalization is not exception. Our first thoughts about this topic include the process of global “McDonaldisation” and, generally speaking, spreading across the whole Globe.Firstly, I would try to concentrate on the positive aspects of globalisation. As far as economy is concerned, like the Global Bank or IMF are always focused on developing the ‘Third World’ and helping poor people to combat their life obstacles (through loans and donations). Moreover, the world becomes an area of sharing thoughts (e.g. philosophical or economical doctrines), which become popular due to lack of barriers.However, disadvantages of globalization are also widely known. Some people insist that because of this process, the spirit of countries and nations rapidly disappears. The integrity, established years ago is on the verge of collapsing. Furthermore, there’s a strong lobby of communists who , that the globalization indicates an uncontrolled reign of capitalists and slave work of lower labour-class. We should never forget about the detrimental impact of global investments on the environment – the green house effect or soar rains are triggered by globalization.To sum up, globalization has both positive and negative influence on our everyday life. I can’t agree with the popular statement that we should try to avoid being affected by it. However, we must not forget about our surroundings and local communities. They have a great value which should last forever.'}"
1139
+ ]
1140
+ },
1141
+ "metadata": {},
1142
+ "execution_count": 6
1143
+ }
1144
+ ]
1145
+ },
1146
+ {
1147
+ "cell_type": "code",
1148
+ "source": [
1149
+ "def preprocess_function(examples):\n",
1150
+ " return tokenizer(examples[\"text\"], truncation=True)"
1151
+ ],
1152
+ "metadata": {
1153
+ "id": "z-Q57XYTMWsU"
1154
+ },
1155
+ "execution_count": 7,
1156
+ "outputs": []
1157
+ },
1158
+ {
1159
+ "cell_type": "code",
1160
+ "source": [
1161
+ "tokenized_dataset = dataset.map(preprocess_function, batched=True)"
1162
+ ],
1163
+ "metadata": {
1164
+ "colab": {
1165
+ "base_uri": "https://localhost:8080/",
1166
+ "height": 81,
1167
+ "referenced_widgets": [
1168
+ "e68b6e6997844bf788a057f9c7feedfb",
1169
+ "295e4080ccd64e48806a36b83e50ddfa",
1170
+ "c4025862f06b412cb99165b67ad7daae",
1171
+ "5ac369dab692489cb13cdb664c47fd96",
1172
+ "434aa0b7bd76440d9b9b64d8b53133d3",
1173
+ "9e2a1fea814f408ebb4d15db83b1130b",
1174
+ "4a2f178864244d68bd915ee57379251d",
1175
+ "7125f94d482a46999fd4dd3be1b3e87e",
1176
+ "96486cdef9714482a4ffa2aca1b3628b",
1177
+ "2364eb3ce5b345788902c5f9d316a00a",
1178
+ "52f799ea10d4403cb18e33ba80d739d3",
1179
+ "3e18acb6f1504f4dace716a96e8d90f4",
1180
+ "953e7d76140e4ed2ade688ccd5467a75",
1181
+ "3a70d75b4eb949598e7cb9430acfcf81",
1182
+ "54719990ff1f40cb8fed06badb378d01",
1183
+ "5d1be2eaa2c143bbbc35f7d0f33f64de",
1184
+ "002c9d35efa54fccb875a08e7059997f",
1185
+ "21dd8d7b7e5a4e27922ff1e3bec7745a",
1186
+ "48abc963896a404886fbcf75b0b19bb9",
1187
+ "87e3a17419334bf8b2448a8914f9d721",
1188
+ "f8303a91b4084791971947ca45c6b459",
1189
+ "a878599cc49347a896c793f3c45914e3"
1190
+ ]
1191
+ },
1192
+ "id": "0-Api6H3Mcqc",
1193
+ "outputId": "5fc02809-9cda-48da-9ac4-fe34f1742c22"
1194
+ },
1195
+ "execution_count": 8,
1196
+ "outputs": [
1197
+ {
1198
+ "output_type": "display_data",
1199
+ "data": {
1200
+ "text/plain": [
1201
+ "Map: 0%| | 0/1148 [00:00<?, ? examples/s]"
1202
+ ],
1203
+ "application/vnd.jupyter.widget-view+json": {
1204
+ "version_major": 2,
1205
+ "version_minor": 0,
1206
+ "model_id": "e68b6e6997844bf788a057f9c7feedfb"
1207
+ }
1208
+ },
1209
+ "metadata": {}
1210
+ },
1211
+ {
1212
+ "output_type": "display_data",
1213
+ "data": {
1214
+ "text/plain": [
1215
+ "Map: 0%| | 0/287 [00:00<?, ? examples/s]"
1216
+ ],
1217
+ "application/vnd.jupyter.widget-view+json": {
1218
+ "version_major": 2,
1219
+ "version_minor": 0,
1220
+ "model_id": "3e18acb6f1504f4dace716a96e8d90f4"
1221
+ }
1222
+ },
1223
+ "metadata": {}
1224
+ }
1225
+ ]
1226
+ },
1227
+ {
1228
+ "cell_type": "code",
1229
+ "source": [
1230
+ "from transformers import DataCollatorWithPadding\n",
1231
+ "\n",
1232
+ "data_collator = DataCollatorWithPadding(tokenizer=tokenizer)"
1233
+ ],
1234
+ "metadata": {
1235
+ "id": "CMgTijF_MkZ-"
1236
+ },
1237
+ "execution_count": 9,
1238
+ "outputs": []
1239
+ },
1240
+ {
1241
+ "cell_type": "code",
1242
+ "source": [
1243
+ "tokenized_dataset['train']"
1244
+ ],
1245
+ "metadata": {
1246
+ "colab": {
1247
+ "base_uri": "https://localhost:8080/"
1248
+ },
1249
+ "id": "pFa_-NPcXQM3",
1250
+ "outputId": "c1379cbf-80ca-433e-86c8-a2a337e10b1b"
1251
+ },
1252
+ "execution_count": 10,
1253
+ "outputs": [
1254
+ {
1255
+ "output_type": "execute_result",
1256
+ "data": {
1257
+ "text/plain": [
1258
+ "Dataset({\n",
1259
+ " features: ['label', 'text', 'input_ids', 'token_type_ids', 'attention_mask'],\n",
1260
+ " num_rows: 1148\n",
1261
+ "})"
1262
+ ]
1263
+ },
1264
+ "metadata": {},
1265
+ "execution_count": 10
1266
+ }
1267
+ ]
1268
+ },
1269
+ {
1270
+ "cell_type": "code",
1271
+ "source": [
1272
+ "import evaluate\n",
1273
+ "\n",
1274
+ "accuracy = evaluate.load(\"accuracy\")"
1275
+ ],
1276
+ "metadata": {
1277
+ "id": "zHjByQbVMobe"
1278
+ },
1279
+ "execution_count": 11,
1280
+ "outputs": []
1281
+ },
1282
+ {
1283
+ "cell_type": "code",
1284
+ "source": [
1285
+ "import numpy as np\n",
1286
+ "\n",
1287
+ "def compute_metrics(eval_pred):\n",
1288
+ "    \"\"\"Return the accuracy of the arg-max class predictions in `eval_pred`.\"\"\"\n",
1289
+ "    logits, references = eval_pred\n",
1290
+ "    predicted_classes = np.argmax(logits, axis=1)\n",
1291
+ "    return accuracy.compute(predictions=predicted_classes, references=references)"
1292
+ ],
1293
+ "metadata": {
1294
+ "id": "GQJysWFsMsyR"
1295
+ },
1296
+ "execution_count": 12,
1297
+ "outputs": []
1298
+ },
1299
+ {
1300
+ "cell_type": "code",
1301
+ "source": [
1302
+ "# id2label = {0: '1', 1:'3', 2:'3.5', 3:'4', 4:'4.5',5:'5', 6:'5.5', 7:'6', 8:'6.5',9:'7',10:'7.5',11:'8',12:'8.5',13:'9'}\n",
1303
+ "# label2id = {'1':0,'3':1,'3.5':2,'4':3,'4.5':4,'5':5,'5.5':6,'6':7,'6.5':8,'7':9,'7.5':10,'8':11,'8.5':12,'9':13}\n",
1304
+ "id2label = {0:\"Bad\",1:\"Acceptable\",2:\"Excellent\"}\n",
1305
+ "label2id = {\"Bad\":0,\"Acceptable\":1,\"Excellent\":2}\n",
1306
+ "\n"
1307
+ ],
1308
+ "metadata": {
1309
+ "id": "HgDWrzrvMvDW"
1310
+ },
1311
+ "execution_count": 13,
1312
+ "outputs": []
1313
+ },
1314
+ {
1315
+ "cell_type": "code",
1316
+ "source": [
1317
+ "from transformers import BertForSequenceClassification, TrainingArguments, Trainer\n",
1318
+ "\n",
1319
+ "model = BertForSequenceClassification.from_pretrained(\n",
1320
+ " \"bert-base-uncased\",num_labels=3, id2label=id2label, label2id=label2id,\n",
1321
+ ")"
1322
+ ],
1323
+ "metadata": {
1324
+ "colab": {
1325
+ "base_uri": "https://localhost:8080/"
1326
+ },
1327
+ "id": "7xaZqPOzOVJP",
1328
+ "outputId": "199584fb-36b3-4906-c3ac-614f9b38950e"
1329
+ },
1330
+ "execution_count": 14,
1331
+ "outputs": [
1332
+ {
1333
+ "output_type": "stream",
1334
+ "name": "stderr",
1335
+ "text": [
1336
+ "Some weights of BertForSequenceClassification were not initialized from the model checkpoint at bert-base-uncased and are newly initialized: ['classifier.weight', 'classifier.bias']\n",
1337
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
1338
+ ]
1339
+ }
1340
+ ]
1341
+ },
1342
+ {
1343
+ "cell_type": "code",
1344
+ "source": [
1345
+ "! pip install transformers[torch]"
1346
+ ],
1347
+ "metadata": {
1348
+ "colab": {
1349
+ "base_uri": "https://localhost:8080/"
1350
+ },
1351
+ "id": "s7bor4hUOq4q",
1352
+ "outputId": "632f838a-b986-43e4-d901-4d5398912fb6"
1353
+ },
1354
+ "execution_count": 15,
1355
+ "outputs": [
1356
+ {
1357
+ "output_type": "stream",
1358
+ "name": "stdout",
1359
+ "text": [
1360
+ "Requirement already satisfied: transformers[torch] in /usr/local/lib/python3.10/dist-packages (4.31.0)\n",
1361
+ "Requirement already satisfied: filelock in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (3.12.2)\n",
1362
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.14.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.16.4)\n",
1363
+ "Requirement already satisfied: numpy>=1.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (1.23.5)\n",
1364
+ "Requirement already satisfied: packaging>=20.0 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (23.1)\n",
1365
+ "Requirement already satisfied: pyyaml>=5.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (6.0.1)\n",
1366
+ "Requirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2023.6.3)\n",
1367
+ "Requirement already satisfied: requests in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.31.0)\n",
1368
+ "Requirement already satisfied: tokenizers!=0.11.3,<0.14,>=0.11.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.13.3)\n",
1369
+ "Requirement already satisfied: safetensors>=0.3.1 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.3.2)\n",
1370
+ "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (4.66.1)\n",
1371
+ "Requirement already satisfied: torch!=1.12.0,>=1.9 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (2.0.1+cu118)\n",
1372
+ "Requirement already satisfied: accelerate>=0.20.3 in /usr/local/lib/python3.10/dist-packages (from transformers[torch]) (0.21.0)\n",
1373
+ "Requirement already satisfied: psutil in /usr/local/lib/python3.10/dist-packages (from accelerate>=0.20.3->transformers[torch]) (5.9.5)\n",
1374
+ "Requirement already satisfied: fsspec in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers[torch]) (2023.6.0)\n",
1375
+ "Requirement already satisfied: typing-extensions>=3.7.4.3 in /usr/local/lib/python3.10/dist-packages (from huggingface-hub<1.0,>=0.14.1->transformers[torch]) (4.7.1)\n",
1376
+ "Requirement already satisfied: sympy in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.9->transformers[torch]) (1.12)\n",
1377
+ "Requirement already satisfied: networkx in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.9->transformers[torch]) (3.1)\n",
1378
+ "Requirement already satisfied: jinja2 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.9->transformers[torch]) (3.1.2)\n",
1379
+ "Requirement already satisfied: triton==2.0.0 in /usr/local/lib/python3.10/dist-packages (from torch!=1.12.0,>=1.9->transformers[torch]) (2.0.0)\n",
1380
+ "Requirement already satisfied: cmake in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch!=1.12.0,>=1.9->transformers[torch]) (3.27.2)\n",
1381
+ "Requirement already satisfied: lit in /usr/local/lib/python3.10/dist-packages (from triton==2.0.0->torch!=1.12.0,>=1.9->transformers[torch]) (16.0.6)\n",
1382
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.2.0)\n",
1383
+ "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (3.4)\n",
1384
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2.0.4)\n",
1385
+ "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests->transformers[torch]) (2023.7.22)\n",
1386
+ "Requirement already satisfied: MarkupSafe>=2.0 in /usr/local/lib/python3.10/dist-packages (from jinja2->torch!=1.12.0,>=1.9->transformers[torch]) (2.1.3)\n",
1387
+ "Requirement already satisfied: mpmath>=0.19 in /usr/local/lib/python3.10/dist-packages (from sympy->torch!=1.12.0,>=1.9->transformers[torch]) (1.3.0)\n"
1388
+ ]
1389
+ }
1390
+ ]
1391
+ },
1392
+ {
1393
+ "cell_type": "code",
1394
+ "source": [
1395
+ "from torch import nn\n",
1396
+ "\n",
1397
+ "class ClassificationTrainer(Trainer):\n",
1398
+ "    \"\"\"Trainer that scores the model's logits with plain cross-entropy.\"\"\"\n",
1399
+ "    def compute_loss(self, model, inputs, return_outputs=False, **kwargs):\n",
1400
+ "        # DataCollatorWithPadding renames \"label\" to \"labels\" in each batch.\n",
1401
+ "        labels = inputs.get(\"labels\")\n",
1402
+ "        outputs = model(**inputs)  # a ModelOutput mapping, not a tensor\n",
1403
+ "        logits = outputs.get(\"logits\")\n",
1404
+ "        loss = nn.CrossEntropyLoss()(logits.view(-1, logits.size(-1)), labels.view(-1))\n",
1405
+ "        return (loss, outputs) if return_outputs else loss"
1406
+ ],
1407
+ "metadata": {
1408
+ "id": "KQ2UskBkU4D9"
1409
+ },
1410
+ "execution_count": 16,
1411
+ "outputs": []
1412
+ },
1413
+ {
1414
+ "cell_type": "code",
1415
+ "source": [
1416
+ "training_args = TrainingArguments(\n",
1417
+ "    output_dir=\"essayl0\",\n",
1418
+ "    learning_rate=2e-5,\n",
1419
+ "    per_device_train_batch_size=16,\n",
1420
+ "    per_device_eval_batch_size=16,\n",
1421
+ "    num_train_epochs=15,\n",
1422
+ "    weight_decay=0.01,\n",
1423
+ "    evaluation_strategy=\"epoch\",\n",
1424
+ "    save_strategy=\"epoch\",\n",
1425
+ "    logging_strategy=\"epoch\",  # log training loss every epoch, not every 500 steps\n",
1426
+ "    save_total_limit=2,  # prune old checkpoints; the best one is always kept\n",
1427
+ "    load_best_model_at_end=True,\n",
1428
+ ")\n",
1429
+ "trainer = Trainer(\n",
1430
+ "    model=model,\n",
1431
+ "    args=training_args,\n",
1432
+ "    train_dataset=tokenized_dataset[\"train\"],\n",
1433
+ "    eval_dataset=tokenized_dataset[\"test\"],\n",
1434
+ "    tokenizer=tokenizer,\n",
1435
+ "    data_collator=data_collator,\n",
1436
+ "    compute_metrics=compute_metrics,\n",
1437
+ ")\n",
1438
+ "trainer.train()"
1439
+ ],
1440
+ "metadata": {
1441
+ "colab": {
1442
+ "base_uri": "https://localhost:8080/",
1443
+ "height": 656
1444
+ },
1445
+ "id": "BwyTlAy0OdRS",
1446
+ "outputId": "dca2b59c-a8d7-40fb-ded3-d0a3685949d7"
1447
+ },
1448
+ "execution_count": 17,
1449
+ "outputs": [
1450
+ {
1451
+ "output_type": "stream",
1452
+ "name": "stderr",
1453
+ "text": [
1454
+ "/usr/local/lib/python3.10/dist-packages/transformers/optimization.py:411: FutureWarning: This implementation of AdamW is deprecated and will be removed in a future version. Use the PyTorch implementation torch.optim.AdamW instead, or set `no_deprecation_warning=True` to disable this warning\n",
1455
+ " warnings.warn(\n",
1456
+ "You're using a BertTokenizerFast tokenizer. Please note that with a fast tokenizer, using the `__call__` method is faster than using a method to encode the text followed by a call to the `pad` method to get a padded encoding.\n"
1457
+ ]
1458
+ },
1459
+ {
1460
+ "output_type": "display_data",
1461
+ "data": {
1462
+ "text/plain": [
1463
+ "<IPython.core.display.HTML object>"
1464
+ ],
1465
+ "text/html": [
1466
+ "\n",
1467
+ " <div>\n",
1468
+ " \n",
1469
+ " <progress value='1080' max='1080' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1470
+ " [1080/1080 29:09, Epoch 15/15]\n",
1471
+ " </div>\n",
1472
+ " <table border=\"1\" class=\"dataframe\">\n",
1473
+ " <thead>\n",
1474
+ " <tr style=\"text-align: left;\">\n",
1475
+ " <th>Epoch</th>\n",
1476
+ " <th>Training Loss</th>\n",
1477
+ " <th>Validation Loss</th>\n",
1478
+ " <th>Accuracy</th>\n",
1479
+ " </tr>\n",
1480
+ " </thead>\n",
1481
+ " <tbody>\n",
1482
+ " <tr>\n",
1483
+ " <td>1</td>\n",
1484
+ " <td>No log</td>\n",
1485
+ " <td>0.601437</td>\n",
1486
+ " <td>0.752613</td>\n",
1487
+ " </tr>\n",
1488
+ " <tr>\n",
1489
+ " <td>2</td>\n",
1490
+ " <td>No log</td>\n",
1491
+ " <td>0.444218</td>\n",
1492
+ " <td>0.860627</td>\n",
1493
+ " </tr>\n",
1494
+ " <tr>\n",
1495
+ " <td>3</td>\n",
1496
+ " <td>No log</td>\n",
1497
+ " <td>0.510611</td>\n",
1498
+ " <td>0.815331</td>\n",
1499
+ " </tr>\n",
1500
+ " <tr>\n",
1501
+ " <td>4</td>\n",
1502
+ " <td>No log</td>\n",
1503
+ " <td>0.723215</td>\n",
1504
+ " <td>0.766551</td>\n",
1505
+ " </tr>\n",
1506
+ " <tr>\n",
1507
+ " <td>5</td>\n",
1508
+ " <td>No log</td>\n",
1509
+ " <td>0.556284</td>\n",
1510
+ " <td>0.850174</td>\n",
1511
+ " </tr>\n",
1512
+ " <tr>\n",
1513
+ " <td>6</td>\n",
1514
+ " <td>No log</td>\n",
1515
+ " <td>0.783423</td>\n",
1516
+ " <td>0.794425</td>\n",
1517
+ " </tr>\n",
1518
+ " <tr>\n",
1519
+ " <td>7</td>\n",
1520
+ " <td>0.275800</td>\n",
1521
+ " <td>0.735923</td>\n",
1522
+ " <td>0.850174</td>\n",
1523
+ " </tr>\n",
1524
+ " <tr>\n",
1525
+ " <td>8</td>\n",
1526
+ " <td>0.275800</td>\n",
1527
+ " <td>0.654791</td>\n",
1528
+ " <td>0.878049</td>\n",
1529
+ " </tr>\n",
1530
+ " <tr>\n",
1531
+ " <td>9</td>\n",
1532
+ " <td>0.275800</td>\n",
1533
+ " <td>0.633503</td>\n",
1534
+ " <td>0.888502</td>\n",
1535
+ " </tr>\n",
1536
+ " <tr>\n",
1537
+ " <td>10</td>\n",
1538
+ " <td>0.275800</td>\n",
1539
+ " <td>1.105006</td>\n",
1540
+ " <td>0.783972</td>\n",
1541
+ " </tr>\n",
1542
+ " <tr>\n",
1543
+ " <td>11</td>\n",
1544
+ " <td>0.275800</td>\n",
1545
+ " <td>0.710119</td>\n",
1546
+ " <td>0.878049</td>\n",
1547
+ " </tr>\n",
1548
+ " <tr>\n",
1549
+ " <td>12</td>\n",
1550
+ " <td>0.275800</td>\n",
1551
+ " <td>0.792314</td>\n",
1552
+ " <td>0.839721</td>\n",
1553
+ " </tr>\n",
1554
+ " <tr>\n",
1555
+ " <td>13</td>\n",
1556
+ " <td>0.275800</td>\n",
1557
+ " <td>0.863435</td>\n",
1558
+ " <td>0.843206</td>\n",
1559
+ " </tr>\n",
1560
+ " <tr>\n",
1561
+ " <td>14</td>\n",
1562
+ " <td>0.018500</td>\n",
1563
+ " <td>0.834555</td>\n",
1564
+ " <td>0.843206</td>\n",
1565
+ " </tr>\n",
1566
+ " <tr>\n",
1567
+ " <td>15</td>\n",
1568
+ " <td>0.018500</td>\n",
1569
+ " <td>0.864810</td>\n",
1570
+ " <td>0.832753</td>\n",
1571
+ " </tr>\n",
1572
+ " </tbody>\n",
1573
+ "</table><p>"
1574
+ ]
1575
+ },
1576
+ "metadata": {}
1577
+ },
1578
+ {
1579
+ "output_type": "execute_result",
1580
+ "data": {
1581
+ "text/plain": [
1582
+ "TrainOutput(global_step=1080, training_loss=0.13700703542541576, metrics={'train_runtime': 1752.9066, 'train_samples_per_second': 9.824, 'train_steps_per_second': 0.616, 'total_flos': 4194210824632584.0, 'train_loss': 0.13700703542541576, 'epoch': 15.0})"
1583
+ ]
1584
+ },
1585
+ "metadata": {},
1586
+ "execution_count": 17
1587
+ }
1588
+ ]
1589
+ },
1590
+ {
1591
+ "cell_type": "code",
1592
+ "source": [
1593
+ "!zip -r /content/checkpoint.zip {trainer.state.best_model_checkpoint}"
1594
+ ],
1595
+ "metadata": {
1596
+ "colab": {
1597
+ "base_uri": "https://localhost:8080/"
1598
+ },
1599
+ "id": "s6wG4purBmfX",
1600
+ "outputId": "3363587c-a6e3-4a40-db80-73d6eaf26cf7"
1601
+ },
1602
+ "execution_count": 18,
1603
+ "outputs": [
1604
+ {
1605
+ "output_type": "stream",
1606
+ "name": "stdout",
1607
+ "text": [
1608
+ " adding: content/essayl0/checkpoint-1080/ (stored 0%)\n",
1609
+ " adding: content/essayl0/checkpoint-1080/special_tokens_map.json (deflated 42%)\n",
1610
+ " adding: content/essayl0/checkpoint-1080/rng_state.pth (deflated 28%)\n",
1611
+ " adding: content/essayl0/checkpoint-1080/vocab.txt (deflated 53%)\n",
1612
+ " adding: content/essayl0/checkpoint-1080/tokenizer.json (deflated 71%)\n",
1613
+ " adding: content/essayl0/checkpoint-1080/config.json (deflated 50%)\n",
1614
+ " adding: content/essayl0/checkpoint-1080/trainer_state.json (deflated 78%)\n",
1615
+ " adding: content/essayl0/checkpoint-1080/pytorch_model.bin (deflated 7%)\n",
1616
+ " adding: content/essayl0/checkpoint-1080/optimizer.pt (deflated 21%)\n",
1617
+ " adding: content/essayl0/checkpoint-1080/training_args.bin (deflated 48%)\n",
1618
+ " adding: content/essayl0/checkpoint-1080/tokenizer_config.json (deflated 43%)\n",
1619
+ " adding: content/essayl0/checkpoint-1080/scheduler.pt (deflated 49%)\n"
1620
+ ]
1621
+ }
1622
+ ]
1623
+ }
1624
+ ]
1625
+ }