chrisvoncsefalvay committed · verified
Commit 5e1c670 · 1 Parent(s): 3c46ad8

Training in progress, step 5000

Files changed (47)
  1. .amlignore +6 -0
  2. .amlignore.amltmp +6 -0
  3. .gitattributes +1 -0
  4. .gitignore +886 -0
  5. config.json +1 -1
  6. data/.amlignore +6 -0
  7. data/.amlignore.amltmp +6 -0
  8. data/.gitkeep +0 -0
  9. data/custom_vocab.txt +0 -0
  10. model.safetensors +2 -2
  11. notebooks/.amlignore +6 -0
  12. notebooks/.amlignore.amltmp +6 -0
  13. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-27-56Z.ipynb +744 -0
  14. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-52-4Z.ipynb +788 -0
  15. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-13-2-30Z.ipynb +1147 -0
  16. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-15-7-36Z.ipynb +1452 -0
  17. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-16-26-9Z.ipynb +1246 -0
  18. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-20-56-58Z.ipynb +993 -0
  19. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-23-54-39Z.ipynb +692 -0
  20. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-3-12-1Z.ipynb +1053 -0
  21. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-4-13-53Z.ipynb +0 -0
  22. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-14-26-30Z.ipynb +739 -0
  23. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-16-5-15Z.ipynb +729 -0
  24. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-17-44-52Z.ipynb +739 -0
  25. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-3-40-27Z.ipynb +1001 -0
  26. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-4-40-54Z.ipynb +1073 -0
  27. notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-30-21-44-8Z.ipynb +671 -0
  28. notebooks/.ipynb_aml_checkpoints/microsample_model_comparison-checkpoint2024-0-31-14-6-22Z.ipynb +0 -0
  29. notebooks/DAEDRA-Copy1.ipynb +1634 -0
  30. notebooks/DAEDRA.ipynb +671 -0
  31. notebooks/DAEDRA.yml +0 -0
  32. notebooks/Dataset preparation.ipynb +524 -0
  33. notebooks/Untitled.ipynb +33 -0
  34. notebooks/comparisons.csv +3 -0
  35. notebooks/daedra.ipynb.amltmp +671 -0
  36. notebooks/daedra.py +134 -0
  37. notebooks/daedra.py.amltmp +134 -0
  38. notebooks/daedra_final_training.py.amltmp +136 -0
  39. notebooks/emissions.csv +3 -0
  40. notebooks/emissions.csv.amltmp +3 -0
  41. notebooks/microsample_model_comparison.ipynb +0 -0
  42. notebooks/tokenizer.json +0 -0
  43. notebooks/wandb/.amlignore +6 -0
  44. notebooks/wandb/.amlignore.amltmp +6 -0
  45. paper/.gitkeep +0 -0
  46. tokenizer.json +6 -1
  47. training_args.bin +1 -1
.amlignore ADDED
@@ -0,0 +1,6 @@
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+ .ipynb_aml_checkpoints/
+ *.amltmp
+ *.amltemp
.amlignore.amltmp ADDED
@@ -0,0 +1,6 @@
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+ .ipynb_aml_checkpoints/
+ *.amltmp
+ *.amltemp
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ notebooks/comparisons.csv filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,886 @@
+ ### JetBrains template
+ # Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+ # Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+ # User-specific stuff
+ .idea/**/workspace.xml
+ .idea/**/tasks.xml
+ .idea/**/usage.statistics.xml
+ .idea/**/dictionaries
+ .idea/**/shelf
+
+ # Data folder
+ data/*.csv
+
+ # AWS User-specific
+ .idea/**/aws.xml
+
+ # Generated files
+ .idea/**/contentModel.xml
+
+ # Sensitive or high-churn files
+ .idea/**/dataSources/
+ .idea/**/dataSources.ids
+ .idea/**/dataSources.local.xml
+ .idea/**/sqlDataSources.xml
+ .idea/**/dynamic.xml
+ .idea/**/uiDesigner.xml
+ .idea/**/dbnavigator.xml
+
+ # Gradle
+ .idea/**/gradle.xml
+ .idea/**/libraries
+
+ # Gradle and Maven with auto-import
+ # When using Gradle or Maven with auto-import, you should exclude module files,
+ # since they will be recreated, and may cause churn. Uncomment if using
+ # auto-import.
+ # .idea/artifacts
+ # .idea/compiler.xml
+ # .idea/jarRepositories.xml
+ # .idea/modules.xml
+ # .idea/*.iml
+ # .idea/modules
+ # *.iml
+ # *.ipr
+
+ # CMake
+ cmake-build-*/
+
+ # Mongo Explorer plugin
+ .idea/**/mongoSettings.xml
+
+ # File-based project format
+ *.iws
+
+ # IntelliJ
+ out/
+
+ # mpeltonen/sbt-idea plugin
+ .idea_modules/
+
+ # JIRA plugin
+ atlassian-ide-plugin.xml
+
+ # Cursive Clojure plugin
+ .idea/replstate.xml
+
+ # SonarLint plugin
+ .idea/sonarlint/
+
+ # Crashlytics plugin (for Android Studio and IntelliJ)
+ com_crashlytics_export_strings.xml
+ crashlytics.properties
+ crashlytics-build.properties
+ fabric.properties
+
+ # Editor-based Rest Client
+ .idea/httpRequests
+
+ # Android studio 3.1+ serialized cache file
+ .idea/caches/build_file_checksums.ser
+
+ ### OSX template
+ # General
+ .DS_Store
+ .AppleDouble
+ .LSOverride
+
+ # Icon must end with two \r
+ Icon
+
+ # Thumbnails
+ ._*
+
+ # Files that might appear in the root of a volume
+ .DocumentRevisions-V100
+ .fseventsd
+ .Spotlight-V100
+ .TemporaryItems
+ .Trashes
+ .VolumeIcon.icns
+ .com.apple.timemachine.donotpresent
+
+ # Directories potentially created on remote AFP share
+ .AppleDB
+ .AppleDesktop
+ Network Trash Folder
+ Temporary Items
+ .apdisk
+
+ ### TeX template
+ ## Core latex/pdflatex auxiliary files:
+ *.aux
+ *.lof
+ *.log
+ *.lot
+ *.fls
+ *.out
+ *.toc
+ *.fmt
+ *.fot
+ *.cb
+ *.cb2
+ .*.lb
+
+ ## Intermediate documents:
+ *.dvi
+ *.xdv
+ *-converted-to.*
+ # these rules might exclude image files for figures etc.
+ # *.ps
+ # *.eps
+ # *.pdf
+
+ ## Generated if empty string is given at "Please type another file name for output:"
+ .pdf
+
+ ## Bibliography auxiliary files (bibtex/biblatex/biber):
+ *.bbl
+ *.bcf
+ *.blg
+ *-blx.aux
+ *-blx.bib
+ *.run.xml
+
+ ## Build tool auxiliary files:
+ *.fdb_latexmk
+ *.synctex
+ *.synctex(busy)
+ *.synctex.gz
+ *.synctex.gz(busy)
+ *.pdfsync
+
+ ## Build tool directories for auxiliary files
+ # latexrun
+ latex.out/
+
+ ## Auxiliary and intermediate files from other packages:
+ # algorithms
+ *.alg
+ *.loa
+
+ # achemso
+ acs-*.bib
+
+ # amsthm
+ *.thm
+
+ # beamer
+ *.nav
+ *.pre
+ *.snm
+ *.vrb
+
+ # changes
+ *.soc
+
+ # comment
+ *.cut
+
+ # cprotect
+ *.cpt
+
+ # elsarticle (documentclass of Elsevier journals)
+ *.spl
+
+ # endnotes
+ *.ent
+
+ *.lox
+
+ # feynmf/feynmp
+ *.mf
+ *.mp
+ *.t[1-9]
+ *.t[1-9][0-9]
+ *.tfm
+
+ #(r)(e)ledmac/(r)(e)ledpar
+ *.end
+ *.?end
+ *.[1-9]
+ *.[1-9][0-9]
+ *.[1-9][0-9][0-9]
+ *.[1-9]R
+ *.[1-9][0-9]R
+ *.[1-9][0-9][0-9]R
+ *.eledsec[1-9]
+ *.eledsec[1-9]R
+ *.eledsec[1-9][0-9]
+ *.eledsec[1-9][0-9]R
+ *.eledsec[1-9][0-9][0-9]
+ *.eledsec[1-9][0-9][0-9]R
+
+ # glossaries
+ *.acn
+ *.acr
+ *.glg
+ *.glo
+ *.gls
+ *.glsdefs
+ *.lzo
+ *.lzs
+ *.slg
+ *.slo
+ *.sls
+
+ # uncomment this for glossaries-extra (will ignore makeindex's style files!)
+ # *.ist
+
+ # gnuplot
+ *.gnuplot
+ *.table
+
+ # gnuplottex
+ *-gnuplottex-*
+
+ # gregoriotex
+ *.gaux
+ *.glog
+ *.gtex
+
+ # htlatex
+ *.4ct
+ *.4tc
+ *.idv
+ *.lg
+ *.trc
+ *.xref
+
+ # hyperref
+ *.brf
+
+ # knitr
+ *-concordance.tex
+ # *.tikz
+ *-tikzDictionary
+
+ # listings
+ *.lol
+
+ # luatexja-ruby
+ *.ltjruby
+
+ # makeidx
+ *.idx
+ *.ilg
+ *.ind
+
+ # minitoc
+ *.maf
+ *.mlf
+ *.mlt
+ *.mtc[0-9]*
+ *.slf[0-9]*
+ *.slt[0-9]*
+ *.stc[0-9]*
+
+ # minted
+ _minted*
+ *.pyg
+
+ # morewrites
+ *.mw
+
+ # newpax
+ *.newpax
+
+ # nomencl
+ *.nlg
+ *.nlo
+ *.nls
+
+ # pax
+ *.pax
+
+ # pdfpcnotes
+ *.pdfpc
+
+ # sagetex
+ *.sagetex.sage
+ *.sagetex.py
+ *.sagetex.scmd
+
+ # scrwfile
+ *.wrt
+
+ # svg
+ svg-inkscape/
+
+ # sympy
+ *.sout
+ *.sympy
+ sympy-plots-for-*.tex/
+
+ # pdfcomment
+ *.upa
+ *.upb
+
+ # pythontex
+ *.pytxcode
+ pythontex-files-*/
+
+ # tcolorbox
+ *.listing
+
+ # thmtools
+ *.loe
+
+ # TikZ & PGF
+ *.dpth
+ *.md5
+ *.auxlock
+
+ # titletoc
+ *.ptc
+
+ # todonotes
+ *.tdo
+
+ # vhistory
+ *.hst
+ *.ver
+
+ *.lod
+
+ # xcolor
+ *.xcp
+
+ # xmpincl
+ *.xmpi
+
+ # xindy
+ *.xdy
+
+ # xypic precompiled matrices and outlines
+ *.xyc
+ *.xyd
+
+ # endfloat
+ *.ttt
+ *.fff
+
+ # Latexian
+ TSWLatexianTemp*
+
+ ## Editors:
+ # WinEdt
+ *.bak
+ *.sav
+
+ # Texpad
+ .texpadtmp
+
+ # LyX
+ *.lyx~
+
+ # Kile
+ *.backup
+
+ # gummi
+ .*.swp
+
+ # KBibTeX
+ *~[0-9]*
+
+ # TeXnicCenter
+ *.tps
+
+ # auto folder when using emacs and auctex
+ ./auto/*
+ *.el
+
+ # expex forward references with \gathertags
+ *-tags.tex
+
+ # standalone packages
+ *.sta
+
+ # Makeindex log files
+ *.lpz
+
+ # xwatermark package
+ *.xwm
+
+ # REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
+ # option is specified. Footnotes are the stored in a file with suffix Notes.bib.
+ # Uncomment the next line to have this generated file ignored.
+ #*Notes.bib
+
+ ### JupyterNotebooks template
+ # gitignore template for Jupyter Notebooks
+ # website: http://jupyter.org/
+
+ .ipynb_checkpoints
+ */.ipynb_checkpoints/*
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # Remove previous ipynb_checkpoints
+ # git rm -r .ipynb_checkpoints/
+
+ ### LaTeX template
+ ## Core latex/pdflatex auxiliary files:
+ *.aux
+ *.lof
+ *.log
+ *.lot
+ *.fls
+ *.out
+ *.toc
+ *.fmt
+ *.fot
+ *.cb
+ *.cb2
+ .*.lb
+
+ ## Intermediate documents:
+ *.dvi
+ *.xdv
+ *-converted-to.*
+ # these rules might exclude image files for figures etc.
+ # *.ps
+ # *.eps
+ # *.pdf
+
+ ## Generated if empty string is given at "Please type another file name for output:"
+ .pdf
+
+ ## Bibliography auxiliary files (bibtex/biblatex/biber):
+ *.bbl
+ *.bcf
+ *.blg
+ *-blx.aux
+ *-blx.bib
+ *.run.xml
+
+ ## Build tool auxiliary files:
+ *.fdb_latexmk
+ *.synctex
+ *.synctex(busy)
+ *.synctex.gz
+ *.synctex.gz(busy)
+ *.pdfsync
+
+ ## Build tool directories for auxiliary files
+ # latexrun
+ latex.out/
+
+ ## Auxiliary and intermediate files from other packages:
+ # algorithms
+ *.alg
+ *.loa
+
+ # achemso
+ acs-*.bib
+
+ # amsthm
+ *.thm
+
+ # beamer
+ *.nav
+ *.pre
+ *.snm
+ *.vrb
+
+ # changes
+ *.soc
+
+ # comment
+ *.cut
+
+ # cprotect
+ *.cpt
+
+ # elsarticle (documentclass of Elsevier journals)
+ *.spl
+
+ # endnotes
+ *.ent
+
+ *.lox
+
+ # feynmf/feynmp
+ *.mf
+ *.mp
+ *.t[1-9]
+ *.t[1-9][0-9]
+ *.tfm
+
+ #(r)(e)ledmac/(r)(e)ledpar
+ *.end
+ *.?end
+ *.[1-9]
+ *.[1-9][0-9]
+ *.[1-9][0-9][0-9]
+ *.[1-9]R
+ *.[1-9][0-9]R
+ *.[1-9][0-9][0-9]R
+ *.eledsec[1-9]
+ *.eledsec[1-9]R
+ *.eledsec[1-9][0-9]
+ *.eledsec[1-9][0-9]R
+ *.eledsec[1-9][0-9][0-9]
+ *.eledsec[1-9][0-9][0-9]R
+
+ # glossaries
+ *.acn
+ *.acr
+ *.glg
+ *.glo
+ *.gls
+ *.glsdefs
+ *.lzo
+ *.lzs
+ *.slg
+ *.slo
+ *.sls
+
+ # uncomment this for glossaries-extra (will ignore makeindex's style files!)
+ # *.ist
+
+ # gnuplot
+ *.gnuplot
+ *.table
+
+ # gnuplottex
+ *-gnuplottex-*
+
+ # gregoriotex
+ *.gaux
+ *.glog
+ *.gtex
+
+ # htlatex
+ *.4ct
+ *.4tc
+ *.idv
+ *.lg
+ *.trc
+ *.xref
+
+ # hyperref
+ *.brf
+
+ # knitr
+ *-concordance.tex
+ # *.tikz
+ *-tikzDictionary
+
+ # listings
+ *.lol
+
+ # luatexja-ruby
+ *.ltjruby
+
+ # makeidx
+ *.idx
+ *.ilg
+ *.ind
+
+ # minitoc
+ *.maf
+ *.mlf
+ *.mlt
+ *.mtc[0-9]*
+ *.slf[0-9]*
+ *.slt[0-9]*
+ *.stc[0-9]*
+
+ # minted
+ _minted*
+ *.pyg
+
+ # morewrites
+ *.mw
+
+ # newpax
+ *.newpax
+
+ # nomencl
+ *.nlg
+ *.nlo
+ *.nls
+
+ # pax
+ *.pax
+
+ # pdfpcnotes
+ *.pdfpc
+
+ # sagetex
+ *.sagetex.sage
+ *.sagetex.py
+ *.sagetex.scmd
+
+ # scrwfile
+ *.wrt
+
+ # svg
+ svg-inkscape/
+
+ # sympy
+ *.sout
+ *.sympy
+ sympy-plots-for-*.tex/
+
+ # pdfcomment
+ *.upa
+ *.upb
+
+ # pythontex
+ *.pytxcode
+ pythontex-files-*/
+
+ # tcolorbox
+ *.listing
+
+ # thmtools
+ *.loe
+
+ # TikZ & PGF
+ *.dpth
+ *.md5
+ *.auxlock
+
+ # titletoc
+ *.ptc
+
+ # todonotes
+ *.tdo
+
+ # vhistory
+ *.hst
+ *.ver
+
+ *.lod
+
+ # xcolor
+ *.xcp
+
+ # xmpincl
+ *.xmpi
+
+ # xindy
+ *.xdy
+
+ # xypic precompiled matrices and outlines
+ *.xyc
+ *.xyd
+
+ # endfloat
+ *.ttt
+ *.fff
+
+ # Latexian
+ TSWLatexianTemp*
+
+ ## Editors:
+ # WinEdt
+ *.bak
+ *.sav
+
+ # Texpad
+ .texpadtmp
+
+ # LyX
+ *.lyx~
+
+ # Kile
+ *.backup
+
+ # gummi
+ .*.swp
+
+ # KBibTeX
+ *~[0-9]*
+
+ # TeXnicCenter
+ *.tps
+
+ # auto folder when using emacs and auctex
+ ./auto/*
+ *.el
+
+ # expex forward references with \gathertags
+ *-tags.tex
+
+ # standalone packages
+ *.sta
+
+ # Makeindex log files
+ *.lpz
+
+ # xwatermark package
+ *.xwm
+
+ # REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
+ # option is specified. Footnotes are the stored in a file with suffix Notes.bib.
+ # Uncomment the next line to have this generated file ignored.
+ #*Notes.bib
+
+ ### Python template
+ # Byte-compiled / optimized / DLL files
+ __pycache__/
+ *.py[cod]
+ *$py.class
+
+ # C extensions
+ *.so
+
+ # Distribution / packaging
+ .Python
+ build/
+ develop-eggs/
+ dist/
+ downloads/
+ eggs/
+ .eggs/
+ lib/
+ lib64/
+ parts/
+ sdist/
+ var/
+ wheels/
+ share/python-wheels/
+ *.egg-info/
+ .installed.cfg
+ *.egg
+ MANIFEST
+
+ # PyInstaller
+ # Usually these files are written by a python script from a template
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
+ *.manifest
+ *.spec
+
+ # Installer logs
+ pip-log.txt
+ pip-delete-this-directory.txt
+
+ # Unit test / coverage reports
+ htmlcov/
+ .tox/
+ .nox/
+ .coverage
+ .coverage.*
+ .cache
+ nosetests.xml
+ coverage.xml
+ *.cover
+ *.py,cover
+ .hypothesis/
+ .pytest_cache/
+ cover/
+
+ # Translations
+ *.mo
+ *.pot
+
+ # Django stuff:
+ *.log
+ local_settings.py
+ db.sqlite3
+ db.sqlite3-journal
+
+ # Flask stuff:
+ instance/
+ .webassets-cache
+
+ # Scrapy stuff:
+ .scrapy
+
+ # Sphinx documentation
+ docs/_build/
+
+ # PyBuilder
+ .pybuilder/
+ target/
+
+ # Jupyter Notebook
+ .ipynb_checkpoints
+
+ # IPython
+ profile_default/
+ ipython_config.py
+
+ # pyenv
+ # For a library or package, you might want to ignore these files since the code is
+ # intended to run in multiple environments; otherwise, check them in:
+ # .python-version
+
+ # pipenv
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
+ # install all needed dependencies.
+ #Pipfile.lock
+
+ # poetry
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
+ # commonly ignored for libraries.
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+ #poetry.lock
+
+ # pdm
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+ #pdm.lock
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+ # in version control.
+ # https://pdm.fming.dev/#use-with-ide
+ .pdm.toml
+
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+ __pypackages__/
+
+ # Celery stuff
+ celerybeat-schedule
+ celerybeat.pid
+
+ # SageMath parsed files
+ *.sage.py
+
+ # Environments
+ .env
+ .venv
+ env/
+ venv/
+ ENV/
+ env.bak/
+ venv.bak/
+
+ # Spyder project settings
+ .spyderproject
+ .spyproject
+
+ # Rope project settings
+ .ropeproject
+
+ # mkdocs documentation
+ /site
+
+ # mypy
+ .mypy_cache/
+ .dmypy.json
+ dmypy.json
+
+ # Pyre type checker
+ .pyre/
+
+ # pytype static type analyzer
+ .pytype/
+
+ # Cython debug symbols
+ cython_debug/
+
+ # PyCharm
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
+ #.idea/
+
config.json CHANGED
@@ -42,5 +42,5 @@
  "transformers_version": "4.37.2",
  "type_vocab_size": 2,
  "use_cache": true,
- "vocab_size": 28996
+ "vocab_size": 52000
  }
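The `vocab_size` bump from 28996 (the stock `bert-base-cased` vocabulary size) to 52000 indicates the checkpoint now pairs the model with a retrained tokenizer and a resized embedding matrix (cf. `data/custom_vocab.txt` and the `tokenizer.json` change in this commit). A minimal sketch of how such a resize is typically done with the `transformers` API; the corpus, checkpoint name, and label count below are illustrative assumptions, not taken from this commit:

from transformers import AutoTokenizer, AutoModelForSequenceClassification

# Hypothetical stand-in for the report corpus actually used for tokenizer training.
corpus = ["Patient experienced fever and fatigue after vaccination."]

base_tokenizer = AutoTokenizer.from_pretrained("bert-base-cased")
# Retrain the (fast) tokenizer on the domain corpus, targeting 52,000 tokens.
new_tokenizer = base_tokenizer.train_new_from_iterator(corpus, vocab_size=52000)

model = AutoModelForSequenceClassification.from_pretrained("bert-base-cased", num_labels=7)
# Resize the embedding matrix to the new vocabulary; this is what would move
# config.json's "vocab_size" from 28996 to 52000.
model.resize_token_embeddings(len(new_tokenizer))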
data/.amlignore ADDED
@@ -0,0 +1,6 @@
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+ .ipynb_aml_checkpoints/
+ *.amltmp
+ *.amltemp
data/.amlignore.amltmp ADDED
@@ -0,0 +1,6 @@
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+ .ipynb_aml_checkpoints/
+ *.amltmp
+ *.amltemp
data/.gitkeep ADDED
File without changes
data/custom_vocab.txt ADDED
File without changes
model.safetensors CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:50f69282f3743ce8bae62eaaa651c74301fe373cb675fddcb86d9ef391b247b6
- size 433289224
+ oid sha256:fd4f0f89ac5e5fa87847f7574a0c4343175b23afd37140f15173824a44dfee61
+ size 503957528
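The weight-file growth is consistent with the vocabulary change above: the LFS pointer grows by 503957528 − 433289224 = 70,668,304 bytes, almost exactly the size of 23,004 new fp32 embedding rows of width 768. A quick worked check (my own arithmetic, not part of the commit):

# Worked check: added bytes vs. an fp32 embedding matrix gaining new rows.
new_rows = 52000 - 28996                  # 23,004 extra vocabulary entries
hidden_size = 768                         # BERT-base hidden dimension
added_bytes = new_rows * hidden_size * 4  # float32 = 4 bytes -> 70,668,288
observed = 503957528 - 433289224          # 70,668,304 bytes per the diff
print(observed - added_bytes)             # 16 bytes left over, plausibly safetensors header growth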
notebooks/.amlignore ADDED
@@ -0,0 +1,6 @@
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+ .ipynb_aml_checkpoints/
+ *.amltmp
+ *.amltemp
notebooks/.amlignore.amltmp ADDED
@@ -0,0 +1,6 @@
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+ .ipynb_aml_checkpoints/
+ *.amltmp
+ *.amltemp
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-27-56Z.ipynb ADDED
@@ -0,0 +1,744 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
+ "\n",
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
+ "import torch\n",
+ "import os\n",
+ "from typing import List\n",
+ "from datasets import load_dataset\n",
+ "import shap\n",
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
+ "\n",
+ "%load_ext watermark"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
+ "type": "CODE",
+ "hide_input_from_viewers": false,
+ "hide_output_from_viewers": false,
+ "report_properties": {
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
+ }
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "outputs": [],
+ "source": [
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "SEED: int = 42\n",
+ "\n",
+ "BATCH_SIZE: int = 8\n",
+ "EPOCHS: int = 1\n",
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
+ "\n",
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
+ " \"ER_VISIT\",\n",
+ " \"HOSPITAL\",\n",
+ " \"OFC_VISIT\",\n",
+ " \"X_STAY\",\n",
+ " \"DISABLE\",\n",
+ " \"D_PRESENTED\"]\n",
+ "\n",
+ "# WandB configuration\n",
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "outputs": [],
+ "source": [
+ "%watermark --iversion"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!nvidia-smi"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Loading the data set"
+ ],
+ "attachments": {},
+ "metadata": {
+ "datalore": {
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
+ "type": "MD",
+ "hide_input_from_viewers": false,
+ "hide_output_from_viewers": false,
+ "report_properties": {
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
+ }
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "outputs": [],
+ "source": [
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Tokenisation and encoding"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def tokenize_and_encode(examples):\n",
+ " return tokenizer(examples[\"text\"], truncation=True)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "cols = dataset[\"train\"].column_names\n",
+ "cols.remove(\"labels\")\n",
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Training"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "class MultiLabelTrainer(Trainer):\n",
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
+ " labels = inputs.pop(\"labels\")\n",
+ " outputs = model(**inputs)\n",
+ " logits = outputs.logits\n",
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
+ " labels.float().view(-1, self.model.config.num_labels))\n",
+ " return (loss, outputs) if return_outputs else loss"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to(\"cuda\")"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
+ " y_pred = torch.from_numpy(y_pred)\n",
+ " y_true = torch.from_numpy(y_true)\n",
+ "\n",
+ " if sigmoid:\n",
+ " y_pred = y_pred.sigmoid()\n",
+ "\n",
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def compute_metrics(eval_pred):\n",
+ " predictions, labels = eval_pred\n",
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "args = TrainingArguments(\n",
+ " output_dir=\"vaers\",\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " learning_rate=2e-5,\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " weight_decay=.01,\n",
+ " report_to=[\"wandb\"]\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "multi_label_trainer = MultiLabelTrainer(\n",
+ " model, \n",
+ " args, \n",
+ " train_dataset=ds_enc[\"train\"], \n",
+ " eval_dataset=ds_enc[\"test\"], \n",
+ " compute_metrics=compute_metrics, \n",
+ " tokenizer=tokenizer\n",
+ ")"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "multi_label_trainer.evaluate()"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "multi_label_trainer.train()"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Evaluation"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We instantiate a classifier `pipeline` and push it to CUDA."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "classifier = pipeline(\"text-classification\", \n",
+ " model, \n",
+ " tokenizer=tokenizer, \n",
+ " device=\"cuda:0\")"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"validate\"][\"text\"], \n",
+ " max_length=255, \n",
+ " pad_to_max_length=True, \n",
+ " return_token_type_ids=True, \n",
+ " truncation=True)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Once we've made the data loadable by putting it into a `DataLoader`, we "
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
+ " torch.tensor(test_encodings['attention_mask']), \n",
+ " torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n",
+ " torch.tensor(test_encodings['token_type_ids']))\n",
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
+ " batch_size=BATCH_SIZE)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "model.eval()\n",
+ "\n",
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
+ "\n",
+ "for i, batch in enumerate(test_dataloader):\n",
+ " batch = tuple(t.to(device) for t in batch)\n",
+ " # Unpack the inputs from our dataloader\n",
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
+ " \n",
+ " with torch.no_grad():\n",
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
+ " b_logit_pred = outs[0]\n",
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
+ "\n",
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
+ " pred_label = pred_label.to('cpu').numpy()\n",
+ " b_labels = b_labels.to('cpu').numpy()\n",
+ "\n",
+ " tokenized_texts.append(b_input_ids)\n",
+ " logit_preds.append(b_logit_pred)\n",
+ " true_labels.append(b_labels)\n",
+ " pred_labels.append(pred_label)\n",
+ "\n",
+ "# Flatten outputs\n",
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
+ "\n",
+ "# Converting flattened binary values to boolean values\n",
+ "true_bools = [tl == 1 for tl in true_labels]\n",
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We create a classification report:"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
+ "print(clf_report)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "eBprrgF086mznPbPVBpOLS",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Creating a map of class names from class numbers\n",
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "true_label_idxs, pred_label_idxs = [], []\n",
+ "\n",
+ "for vals in true_bools:\n",
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
+ "for vals in pred_bools:\n",
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "jH0S35dDteUch01sa6me6e",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "true_label_texts, pred_label_texts = [], []\n",
+ "\n",
+ "for vals in true_label_idxs:\n",
+ " if vals:\n",
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
+ " else:\n",
+ " true_label_texts.append(vals)\n",
+ "\n",
+ "for vals in pred_label_idxs:\n",
+ " if vals:\n",
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
+ " else:\n",
+ " pred_label_texts.append(vals)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "symptom_texts = [tokenizer.decode(text,\n",
+ " skip_special_tokens=True,\n",
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
+ " 'true_labels': true_label_texts, \n",
+ " 'pred_labels':pred_label_texts})\n",
+ "comparisons_df.to_csv('comparisons.csv')\n",
+ "comparisons_df"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "### Shapley analysis"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "OpdZcoenX2HwzLdai7K5UA",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "FvbCMfIDlcf16YSvb8wNQv",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "shap.plots.text(shap_values)"
+ ],
+ "execution_count": null,
+ "outputs": [],
+ "metadata": {
+ "datalore": {
+ "node_id": "TSxvakWLPCpjVMWi9ZdEbd",
+ "type": "CODE",
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true
+ }
+ }
+ }
+ ],
+ "metadata": {
+ "kernelspec": {
+ "name": "python3",
+ "language": "python",
+ "display_name": "Python 3 (ipykernel)"
+ },
+ "datalore": {
+ "computation_mode": "JUPYTER",
+ "package_manager": "pip",
+ "base_environment": "default",
+ "packages": [
+ {
+ "name": "datasets",
+ "version": "2.16.1",
+ "source": "PIP"
+ },
+ {
+ "name": "torch",
+ "version": "2.1.2",
+ "source": "PIP"
+ },
+ {
+ "name": "accelerate",
+ "version": "0.26.1",
+ "source": "PIP"
+ }
+ ],
+ "report_row_ids": [
+ "un8W7ez7ZwoGb5Co6nydEV",
+ "40nN9Hvgi1clHNV5RAemI5",
+ "TgRD90H5NSPpKS41OeXI1w",
+ "ZOm5BfUs3h1EGLaUkBGeEB",
+ "kOP0CZWNSk6vqE3wkPp7Vc",
+ "W4PWcOu2O2pRaZyoE2W80h",
+ "RolbOnQLIftk0vy9mIcz5M",
+ "8OPhUgbaNJmOdiq5D3a6vK",
+ "5Qrt3jSvSrpK6Ne1hS6shL",
+ "hTq7nFUrovN5Ao4u6dIYWZ",
+ "I8WNZLpJ1DVP2wiCW7YBIB",
+ "SawhU3I9BewSE1XBPstpNJ",
+ "80EtLEl2FIE4FqbWnUD3nT"
+ ],
+ "version": 3
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
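For readers skimming the diff: the notebook above trains a multi-label classifier via a custom `MultiLabelTrainer` whose loss is `BCEWithLogitsLoss` over seven outcome labels. Note that the cell calling `from_pretrained(..., num_labels=num_labels)` never defines `num_labels`; it is presumably `len(CLASS_NAMES)`. A self-contained sketch of that loss and the 0.5-threshold decision rule the evaluation cells use, on illustrative tensors:

import torch

CLASS_NAMES = ["DIED", "ER_VISIT", "HOSPITAL", "OFC_VISIT", "X_STAY", "DISABLE", "D_PRESENTED"]
num_labels = len(CLASS_NAMES)  # the value the notebook presumably intends

# Illustrative batch: raw logits from the classification head and multi-hot labels.
logits = torch.randn(4, num_labels)
labels = torch.randint(0, 2, (4, num_labels)).float()

# The same loss MultiLabelTrainer.compute_loss applies:
loss = torch.nn.BCEWithLogitsLoss()(logits.view(-1, num_labels),
                                    labels.view(-1, num_labels))

# Inference-side thresholding, mirroring pred_bools = [pl > 0.50 ...]:
pred_bools = torch.sigmoid(logits) > 0.5
print(loss.item(), pred_bools.shape)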
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-52-4Z.ipynb ADDED
@@ -0,0 +1,788 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {
11
+ "collapsed": false
12
+ }
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "source": [
17
+ "import pandas as pd\n",
18
+ "import numpy as np\n",
19
+ "import torch\n",
20
+ "import os\n",
21
+ "from typing import List\n",
22
+ "from datasets import load_dataset\n",
23
+ "import shap\n",
24
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
25
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
26
+ "\n",
27
+ "%load_ext watermark"
28
+ ],
29
+ "outputs": [
30
+ {
31
+ "output_type": "error",
32
+ "ename": "ModuleNotFoundError",
33
+ "evalue": "No module named 'torch'",
34
+ "traceback": [
35
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
36
+ "\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
37
+ "Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List\n",
38
+ "\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'"
39
+ ]
40
+ }
41
+ ],
42
+ "execution_count": 2,
43
+ "metadata": {
44
+ "datalore": {
45
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
46
+ "type": "CODE",
47
+ "hide_input_from_viewers": false,
48
+ "hide_output_from_viewers": false,
49
+ "report_properties": {
50
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
51
+ }
52
+ },
53
+ "gather": {
54
+ "logged": 1706406690290
55
+ }
56
+ }
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
62
+ "\n",
63
+ "SEED: int = 42\n",
64
+ "\n",
65
+ "BATCH_SIZE: int = 8\n",
66
+ "EPOCHS: int = 1\n",
67
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
68
+ "\n",
69
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
70
+ " \"ER_VISIT\",\n",
71
+ " \"HOSPITAL\",\n",
72
+ " \"OFC_VISIT\",\n",
73
+ " \"X_STAY\",\n",
74
+ " \"DISABLE\",\n",
75
+ " \"D_PRESENTED\"]\n",
76
+ "\n",
77
+ "# WandB configuration\n",
78
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
79
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
80
+ ],
81
+ "outputs": [],
82
+ "execution_count": null,
83
+ "metadata": {
84
+ "collapsed": false
85
+ }
86
+ },
87
+ {
88
+ "cell_type": "code",
89
+ "source": [
90
+ "%watermark --iversion"
91
+ ],
92
+ "outputs": [],
93
+ "execution_count": null,
94
+ "metadata": {
95
+ "collapsed": false
96
+ }
97
+ },
98
+ {
99
+ "cell_type": "code",
100
+ "source": [
101
+ "!nvidia-smi"
102
+ ],
103
+ "outputs": [
104
+ {
105
+ "output_type": "stream",
106
+ "name": "stdout",
107
+ "text": "Sun Jan 28 01:31:42 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 0MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 29C P0 36W / 250W | 0MiB / 16384MiB | 1% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
108
+ }
109
+ ],
110
+ "execution_count": 4,
111
+ "metadata": {
112
+ "datalore": {
113
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
114
+ "type": "CODE",
115
+ "hide_input_from_viewers": true,
116
+ "hide_output_from_viewers": true
117
+ }
118
+ }
119
+ },
120
+ {
121
+ "attachments": {},
122
+ "cell_type": "markdown",
123
+ "source": [
124
+ "## Loading the data set"
125
+ ],
126
+ "metadata": {
127
+ "datalore": {
128
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
129
+ "type": "MD",
130
+ "hide_input_from_viewers": false,
131
+ "hide_output_from_viewers": false,
132
+ "report_properties": {
133
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
134
+ }
135
+ }
136
+ }
137
+ },
138
+ {
139
+ "cell_type": "code",
140
+ "source": [
141
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
142
+ ],
143
+ "outputs": [],
144
+ "execution_count": null,
145
+ "metadata": {
146
+ "collapsed": false
147
+ }
148
+ },
149
+ {
150
+ "cell_type": "markdown",
151
+ "source": [
152
+ "### Tokenisation and encoding"
153
+ ],
154
+ "metadata": {
155
+ "collapsed": false
156
+ }
157
+ },
158
+ {
159
+ "cell_type": "code",
160
+ "source": [
161
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
162
+ ],
163
+ "outputs": [],
164
+ "execution_count": null,
165
+ "metadata": {
166
+ "datalore": {
167
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
168
+ "type": "CODE",
169
+ "hide_input_from_viewers": true,
170
+ "hide_output_from_viewers": true
171
+ }
172
+ }
173
+ },
174
+ {
175
+ "cell_type": "code",
176
+ "source": [
177
+ "def tokenize_and_encode(examples):\n",
178
+ " return tokenizer(examples[\"text\"], truncation=True)"
179
+ ],
180
+ "outputs": [],
181
+ "execution_count": null,
182
+ "metadata": {
183
+ "datalore": {
184
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
185
+ "type": "CODE",
186
+ "hide_input_from_viewers": true,
187
+ "hide_output_from_viewers": true
188
+ }
189
+ }
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "source": [
194
+ "cols = dataset[\"train\"].column_names\n",
195
+ "cols.remove(\"labels\")\n",
196
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
197
+ ],
198
+ "outputs": [],
199
+ "execution_count": null,
200
+ "metadata": {
201
+ "datalore": {
202
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
203
+ "type": "CODE",
204
+ "hide_input_from_viewers": true,
205
+ "hide_output_from_viewers": true
206
+ }
207
+ }
208
+ },
209
+ {
210
+ "cell_type": "markdown",
211
+ "source": [
212
+ "### Training"
213
+ ],
214
+ "metadata": {
215
+ "collapsed": false
216
+ }
217
+ },
218
+ {
219
+ "cell_type": "code",
220
+ "source": [
221
+ "class MultiLabelTrainer(Trainer):\n",
222
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
223
+ " labels = inputs.pop(\"labels\")\n",
224
+ " outputs = model(**inputs)\n",
225
+ " logits = outputs.logits\n",
226
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
227
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
228
+ " labels.float().view(-1, self.model.config.num_labels))\n",
229
+ " return (loss, outputs) if return_outputs else loss"
230
+ ],
231
+ "outputs": [],
232
+ "execution_count": null,
233
+ "metadata": {
234
+ "datalore": {
235
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
236
+ "type": "CODE",
237
+ "hide_input_from_viewers": true,
238
+ "hide_output_from_viewers": true
239
+ }
240
+ }
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "source": [
245
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to(\"cuda\")"
246
+ ],
247
+ "outputs": [],
248
+ "execution_count": null,
249
+ "metadata": {
250
+ "datalore": {
251
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
252
+ "type": "CODE",
253
+ "hide_input_from_viewers": true,
254
+ "hide_output_from_viewers": true
255
+ }
256
+ }
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "source": [
261
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
262
+ " y_pred = torch.from_numpy(y_pred)\n",
263
+ " y_true = torch.from_numpy(y_true)\n",
264
+ "\n",
265
+ " if sigmoid:\n",
266
+ " y_pred = y_pred.sigmoid()\n",
267
+ "\n",
268
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
269
+ ],
270
+ "outputs": [],
271
+ "execution_count": null,
272
+ "metadata": {
273
+ "datalore": {
274
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
275
+ "type": "CODE",
276
+ "hide_input_from_viewers": true,
277
+ "hide_output_from_viewers": true
278
+ }
279
+ }
280
+ },
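+ {
+ "cell_type": "markdown",
+ "source": [
+ "`accuracy_threshold` sigmoids the logits, binarises them at the threshold, and returns the proportion of individual label slots (not whole examples) that match. A hand-picked example: six of the eight slots below agree after thresholding, so the result is 0.75."
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative only: two invented examples with four label slots each\n",
+ "accuracy_threshold(np.array([[3.0, -2.0, 0.1, -0.2], [1.5, 2.0, -1.0, -0.5]]),\n",
+ "                   np.array([[1.0, 0.0, 1.0, 1.0], [1.0, 0.0, 0.0, 0.0]]))"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ }
+ },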
281
+ {
282
+ "cell_type": "code",
283
+ "source": [
284
+ "def compute_metrics(eval_pred):\n",
285
+ " predictions, labels = eval_pred\n",
286
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
287
+ ],
288
+ "outputs": [],
289
+ "execution_count": null,
290
+ "metadata": {
291
+ "datalore": {
292
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
293
+ "type": "CODE",
294
+ "hide_input_from_viewers": true,
295
+ "hide_output_from_viewers": true
296
+ }
297
+ }
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "source": [
302
+ "args = TrainingArguments(\n",
303
+ " output_dir=\"vaers\",\n",
304
+ " evaluation_strategy=\"epoch\",\n",
305
+ " learning_rate=2e-5,\n",
306
+ " per_device_train_batch_size=BATCH_SIZE,\n",
307
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
308
+ " num_train_epochs=EPOCHS,\n",
309
+ " weight_decay=.01,\n",
310
+ " report_to=[\"wandb\"]\n",
311
+ ")"
312
+ ],
313
+ "outputs": [],
314
+ "execution_count": null,
315
+ "metadata": {
316
+ "datalore": {
317
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
318
+ "type": "CODE",
319
+ "hide_input_from_viewers": true,
320
+ "hide_output_from_viewers": true
321
+ }
322
+ }
323
+ },
324
+ {
325
+ "cell_type": "code",
326
+ "source": [
327
+ "multi_label_trainer = MultiLabelTrainer(\n",
328
+ " model, \n",
329
+ " args, \n",
330
+ " train_dataset=ds_enc[\"train\"], \n",
331
+ " eval_dataset=ds_enc[\"test\"], \n",
332
+ " compute_metrics=compute_metrics, \n",
333
+ " tokenizer=tokenizer\n",
334
+ ")"
335
+ ],
336
+ "outputs": [],
337
+ "execution_count": null,
338
+ "metadata": {
339
+ "datalore": {
340
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
341
+ "type": "CODE",
342
+ "hide_input_from_viewers": true,
343
+ "hide_output_from_viewers": true
344
+ }
345
+ }
346
+ },
347
+ {
348
+ "cell_type": "code",
349
+ "source": [
350
+ "multi_label_trainer.evaluate()"
351
+ ],
352
+ "outputs": [],
353
+ "execution_count": null,
354
+ "metadata": {
355
+ "datalore": {
356
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
357
+ "type": "CODE",
358
+ "hide_input_from_viewers": true,
359
+ "hide_output_from_viewers": true
360
+ }
361
+ }
362
+ },
363
+ {
364
+ "cell_type": "code",
365
+ "source": [
366
+ "multi_label_trainer.train()"
367
+ ],
368
+ "outputs": [],
369
+ "execution_count": null,
370
+ "metadata": {
371
+ "datalore": {
372
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
373
+ "type": "CODE",
374
+ "hide_input_from_viewers": true,
375
+ "hide_output_from_viewers": true
376
+ }
377
+ }
378
+ },
379
+ {
380
+ "cell_type": "markdown",
381
+ "source": [
382
+ "### Evaluation"
383
+ ],
384
+ "metadata": {
385
+ "collapsed": false
386
+ }
387
+ },
388
+ {
389
+ "cell_type": "markdown",
390
+ "source": [
391
+ "We instantiate a classifier `pipeline` and push it to CUDA."
392
+ ],
393
+ "metadata": {
394
+ "collapsed": false
395
+ }
396
+ },
397
+ {
398
+ "cell_type": "code",
399
+ "source": [
400
+ "classifier = pipeline(\"text-classification\", \n",
401
+ " model, \n",
402
+ " tokenizer=tokenizer, \n",
403
+ " device=\"cuda:0\")"
404
+ ],
405
+ "outputs": [],
406
+ "execution_count": null,
407
+ "metadata": {
408
+ "datalore": {
409
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
410
+ "type": "CODE",
411
+ "hide_input_from_viewers": true,
412
+ "hide_output_from_viewers": true
413
+ }
414
+ }
415
+ },
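+ {
+ "cell_type": "markdown",
+ "source": [
+ "The pipeline can then score a single report directly (the text below is invented, purely to show the call):"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative only: classify one invented report\n",
+ "classifier(\"Patient presented to the emergency department with urticaria two days after vaccination.\")"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ }
+ },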
416
+ {
417
+ "cell_type": "markdown",
418
+ "source": [
419
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
420
+ ],
421
+ "metadata": {
422
+ "collapsed": false
423
+ }
424
+ },
425
+ {
426
+ "cell_type": "code",
427
+ "source": [
428
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"validate\"][\"text\"], \n",
429
+ " max_length=255, \n",
430
+ " pad_to_max_length=True, \n",
431
+ " return_token_type_ids=True, \n",
432
+ " truncation=True)"
433
+ ],
434
+ "outputs": [],
435
+ "execution_count": null,
436
+ "metadata": {
437
+ "datalore": {
438
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
439
+ "type": "CODE",
440
+ "hide_input_from_viewers": true,
441
+ "hide_output_from_viewers": true
442
+ }
443
+ }
444
+ },
445
+ {
446
+ "cell_type": "markdown",
447
+ "source": [
448
+ "Once we've made the data loadable by putting it into a `DataLoader`, we "
449
+ ],
450
+ "metadata": {
451
+ "collapsed": false
452
+ }
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "source": [
457
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
458
+ " torch.tensor(test_encodings['attention_mask']), \n",
459
+ " torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n",
460
+ " torch.tensor(test_encodings['token_type_ids']))\n",
461
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
462
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
463
+ " batch_size=BATCH_SIZE)"
464
+ ],
465
+ "outputs": [],
466
+ "execution_count": null,
467
+ "metadata": {
468
+ "datalore": {
469
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
470
+ "type": "CODE",
471
+ "hide_input_from_viewers": true,
472
+ "hide_output_from_viewers": true
473
+ }
474
+ }
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "source": [
479
+ "model.eval()\n",
480
+ "\n",
481
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
482
+ "\n",
483
+ "for i, batch in enumerate(test_dataloader):\n",
484
+ " batch = tuple(t.to(device) for t in batch)\n",
485
+ " # Unpack the inputs from our dataloader\n",
486
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
487
+ " \n",
488
+ " with torch.no_grad():\n",
489
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
490
+ " b_logit_pred = outs[0]\n",
491
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
492
+ "\n",
493
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
494
+ " pred_label = pred_label.to('cpu').numpy()\n",
495
+ " b_labels = b_labels.to('cpu').numpy()\n",
496
+ "\n",
497
+ " tokenized_texts.append(b_input_ids)\n",
498
+ " logit_preds.append(b_logit_pred)\n",
499
+ " true_labels.append(b_labels)\n",
500
+ " pred_labels.append(pred_label)\n",
501
+ "\n",
502
+ "# Flatten outputs\n",
503
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
504
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
505
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
506
+ "\n",
507
+ "# Converting flattened binary values to boolean values\n",
508
+ "true_bools = [tl == 1 for tl in true_labels]\n",
509
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
510
+ ],
511
+ "outputs": [],
512
+ "execution_count": null,
513
+ "metadata": {
514
+ "datalore": {
515
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
516
+ "type": "CODE",
517
+ "hide_input_from_viewers": true,
518
+ "hide_output_from_viewers": true
519
+ }
520
+ }
521
+ },
522
+ {
523
+ "cell_type": "markdown",
524
+ "source": [
525
+ "We create a classification report:"
526
+ ],
527
+ "metadata": {
528
+ "collapsed": false
529
+ }
530
+ },
531
+ {
532
+ "cell_type": "code",
533
+ "source": [
534
+ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
535
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
536
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
537
+ "print(clf_report)"
538
+ ],
539
+ "outputs": [],
540
+ "execution_count": null,
541
+ "metadata": {
542
+ "datalore": {
543
+ "node_id": "eBprrgF086mznPbPVBpOLS",
544
+ "type": "CODE",
545
+ "hide_input_from_viewers": true,
546
+ "hide_output_from_viewers": true
547
+ }
548
+ }
549
+ },
550
+ {
551
+ "cell_type": "markdown",
552
+ "source": [
553
+ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
554
+ ],
555
+ "metadata": {
556
+ "collapsed": false
557
+ }
558
+ },
559
+ {
560
+ "cell_type": "code",
561
+ "source": [
562
+ "# Creating a map of class names from class numbers\n",
563
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
564
+ ],
565
+ "outputs": [],
566
+ "execution_count": null,
567
+ "metadata": {
568
+ "datalore": {
569
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
570
+ "type": "CODE",
571
+ "hide_input_from_viewers": true,
572
+ "hide_output_from_viewers": true
573
+ }
574
+ }
575
+ },
576
+ {
577
+ "cell_type": "code",
578
+ "source": [
579
+ "true_label_idxs, pred_label_idxs = [], []\n",
580
+ "\n",
581
+ "for vals in true_bools:\n",
582
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
583
+ "for vals in pred_bools:\n",
584
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
585
+ ],
586
+ "outputs": [],
587
+ "execution_count": null,
588
+ "metadata": {
589
+ "datalore": {
590
+ "node_id": "jH0S35dDteUch01sa6me6e",
591
+ "type": "CODE",
592
+ "hide_input_from_viewers": true,
593
+ "hide_output_from_viewers": true
594
+ }
595
+ }
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "source": [
600
+ "true_label_texts, pred_label_texts = [], []\n",
601
+ "\n",
602
+ "for vals in true_label_idxs:\n",
603
+ " if vals:\n",
604
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
605
+ " else:\n",
606
+ " true_label_texts.append(vals)\n",
607
+ "\n",
608
+ "for vals in pred_label_idxs:\n",
609
+ " if vals:\n",
610
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
611
+ " else:\n",
612
+ " pred_label_texts.append(vals)"
613
+ ],
614
+ "outputs": [],
615
+ "execution_count": null,
616
+ "metadata": {
617
+ "datalore": {
618
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
619
+ "type": "CODE",
620
+ "hide_input_from_viewers": true,
621
+ "hide_output_from_viewers": true
622
+ }
623
+ }
624
+ },
625
+ {
626
+ "cell_type": "code",
627
+ "source": [
628
+ "symptom_texts = [tokenizer.decode(text,\n",
629
+ " skip_special_tokens=True,\n",
630
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
631
+ ],
632
+ "outputs": [],
633
+ "execution_count": null,
634
+ "metadata": {
635
+ "datalore": {
636
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
637
+ "type": "CODE",
638
+ "hide_input_from_viewers": true,
639
+ "hide_output_from_viewers": true
640
+ }
641
+ }
642
+ },
643
+ {
644
+ "cell_type": "code",
645
+ "source": [
646
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
647
+ " 'true_labels': true_label_texts, \n",
648
+ " 'pred_labels':pred_label_texts})\n",
649
+ "comparisons_df.to_csv('comparisons.csv')\n",
650
+ "comparisons_df"
651
+ ],
652
+ "outputs": [],
653
+ "execution_count": null,
654
+ "metadata": {
655
+ "datalore": {
656
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
657
+ "type": "CODE",
658
+ "hide_input_from_viewers": true,
659
+ "hide_output_from_viewers": true
660
+ }
661
+ }
662
+ },
663
+ {
664
+ "cell_type": "markdown",
665
+ "source": [
666
+ "### Shapley analysis"
667
+ ],
668
+ "metadata": {
669
+ "collapsed": false
670
+ }
671
+ },
672
+ {
673
+ "cell_type": "code",
674
+ "source": [
675
+ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
676
+ ],
677
+ "outputs": [],
678
+ "execution_count": null,
679
+ "metadata": {
680
+ "datalore": {
681
+ "node_id": "OpdZcoenX2HwzLdai7K5UA",
682
+ "type": "CODE",
683
+ "hide_input_from_viewers": true,
684
+ "hide_output_from_viewers": true
685
+ }
686
+ }
687
+ },
688
+ {
689
+ "cell_type": "code",
690
+ "source": [
691
+ "shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])"
692
+ ],
693
+ "outputs": [],
694
+ "execution_count": null,
695
+ "metadata": {
696
+ "datalore": {
697
+ "node_id": "FvbCMfIDlcf16YSvb8wNQv",
698
+ "type": "CODE",
699
+ "hide_input_from_viewers": true,
700
+ "hide_output_from_viewers": true
701
+ }
702
+ }
703
+ },
704
+ {
705
+ "cell_type": "code",
706
+ "source": [
707
+ "shap.plots.text(shap_values)"
708
+ ],
709
+ "outputs": [],
710
+ "execution_count": null,
711
+ "metadata": {
712
+ "datalore": {
713
+ "node_id": "TSxvakWLPCpjVMWi9ZdEbd",
714
+ "type": "CODE",
715
+ "hide_input_from_viewers": true,
716
+ "hide_output_from_viewers": true
717
+ }
718
+ }
719
+ }
720
+ ],
721
+ "metadata": {
722
+ "kernelspec": {
723
+ "name": "python3",
724
+ "language": "python",
725
+ "display_name": "Python 3 (ipykernel)"
726
+ },
727
+ "datalore": {
728
+ "computation_mode": "JUPYTER",
729
+ "package_manager": "pip",
730
+ "base_environment": "default",
731
+ "packages": [
732
+ {
733
+ "name": "datasets",
734
+ "version": "2.16.1",
735
+ "source": "PIP"
736
+ },
737
+ {
738
+ "name": "torch",
739
+ "version": "2.1.2",
740
+ "source": "PIP"
741
+ },
742
+ {
743
+ "name": "accelerate",
744
+ "version": "0.26.1",
745
+ "source": "PIP"
746
+ }
747
+ ],
748
+ "report_row_ids": [
749
+ "un8W7ez7ZwoGb5Co6nydEV",
750
+ "40nN9Hvgi1clHNV5RAemI5",
751
+ "TgRD90H5NSPpKS41OeXI1w",
752
+ "ZOm5BfUs3h1EGLaUkBGeEB",
753
+ "kOP0CZWNSk6vqE3wkPp7Vc",
754
+ "W4PWcOu2O2pRaZyoE2W80h",
755
+ "RolbOnQLIftk0vy9mIcz5M",
756
+ "8OPhUgbaNJmOdiq5D3a6vK",
757
+ "5Qrt3jSvSrpK6Ne1hS6shL",
758
+ "hTq7nFUrovN5Ao4u6dIYWZ",
759
+ "I8WNZLpJ1DVP2wiCW7YBIB",
760
+ "SawhU3I9BewSE1XBPstpNJ",
761
+ "80EtLEl2FIE4FqbWnUD3nT"
762
+ ],
763
+ "version": 3
764
+ },
765
+ "microsoft": {
766
+ "ms_spell_check": {
767
+ "ms_spell_check_language": "en"
768
+ }
769
+ },
770
+ "language_info": {
771
+ "name": "python",
772
+ "version": "3.8.5",
773
+ "mimetype": "text/x-python",
774
+ "codemirror_mode": {
775
+ "name": "ipython",
776
+ "version": 3
777
+ },
778
+ "pygments_lexer": "ipython3",
779
+ "nbconvert_exporter": "python",
780
+ "file_extension": ".py"
781
+ },
782
+ "nteract": {
783
+ "version": "nteract-front-end@1.0.0"
784
+ }
785
+ },
786
+ "nbformat": 4,
787
+ "nbformat_minor": 4
788
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-13-2-30Z.ipynb ADDED
@@ -0,0 +1,1147 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {
11
+ "collapsed": false
12
+ }
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "source": [
17
+ "%pip install accelerate -U"
18
+ ],
19
+ "outputs": [
20
+ {
21
+ "output_type": "stream",
22
+ "name": "stdout",
23
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nNote: you may need to restart the kernel to use updated packages.\n"
24
+ }
25
+ ],
26
+ "execution_count": 1,
27
+ "metadata": {
28
+ "jupyter": {
29
+ "source_hidden": false,
30
+ "outputs_hidden": false
31
+ },
32
+ "nteract": {
33
+ "transient": {
34
+ "deleting": false
35
+ }
36
+ }
37
+ }
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "source": [
42
+ "%pip install transformers datasets shap watermark wandb"
43
+ ],
44
+ "outputs": [
45
+ {
46
+ "output_type": "stream",
47
+ "name": "stdout",
48
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already 
satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) 
(2.2.1)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nNote: you may need to restart the kernel to use updated packages.\n"
49
+ }
50
+ ],
51
+ "execution_count": 2,
52
+ "metadata": {
53
+ "jupyter": {
54
+ "source_hidden": false,
55
+ "outputs_hidden": false
56
+ },
57
+ "nteract": {
58
+ "transient": {
59
+ "deleting": false
60
+ }
61
+ }
62
+ }
63
+ },
64
+ {
65
+ "cell_type": "code",
66
+ "source": [
67
+ "import pandas as pd\n",
68
+ "import numpy as np\n",
69
+ "import torch\n",
70
+ "import os\n",
71
+ "from typing import List\n",
72
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
73
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
74
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
75
+ "from pyarrow import Table\n",
76
+ "import shap\n",
77
+ "\n",
78
+ "%load_ext watermark"
79
+ ],
80
+ "outputs": [
81
+ {
82
+ "output_type": "stream",
83
+ "name": "stderr",
84
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 04:14:37.393442: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 04:14:38.436146: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 04:14:38.436275: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 04:14:38.436289: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
85
+ }
86
+ ],
87
+ "execution_count": 3,
88
+ "metadata": {
89
+ "datalore": {
90
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
91
+ "type": "CODE",
92
+ "hide_input_from_viewers": false,
93
+ "hide_output_from_viewers": false,
94
+ "report_properties": {
95
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
96
+ }
97
+ },
98
+ "gather": {
99
+ "logged": 1706415280692
100
+ }
101
+ }
102
+ },
103
+ {
104
+ "cell_type": "code",
105
+ "source": [
106
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
107
+ "\n",
108
+ "SEED: int = 42\n",
109
+ "\n",
110
+ "BATCH_SIZE: int = 8\n",
111
+ "EPOCHS: int = 1\n",
112
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
113
+ "\n",
114
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
115
+ " \"ER_VISIT\",\n",
116
+ " \"HOSPITAL\",\n",
117
+ " \"OFC_VISIT\",\n",
118
+ " #\"X_STAY\", # pruned\n",
119
+ " #\"DISABLE\", # pruned\n",
120
+ " #\"D_PRESENTED\" # pruned\n",
121
+ " ]\n",
122
+ "\n",
123
+ "\n",
124
+ "\n",
125
+ "\n",
126
+ "# WandB configuration\n",
127
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
128
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
129
+ ],
130
+ "outputs": [],
131
+ "execution_count": 4,
132
+ "metadata": {
133
+ "collapsed": false,
134
+ "gather": {
135
+ "logged": 1706415281102
136
+ }
137
+ }
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "source": [
142
+ "%watermark --iversion"
143
+ ],
144
+ "outputs": [
145
+ {
146
+ "output_type": "stream",
147
+ "name": "stdout",
148
+ "text": "re : 2.2.1\nlogging: 0.5.1.2\nnumpy : 1.23.5\nshap : 0.44.1\npandas : 2.0.2\ntorch : 1.12.0\n\n"
149
+ }
150
+ ],
151
+ "execution_count": 5,
152
+ "metadata": {
153
+ "collapsed": false
154
+ }
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "source": [
159
+ "!nvidia-smi"
160
+ ],
161
+ "outputs": [
162
+ {
163
+ "output_type": "stream",
164
+ "name": "stdout",
165
+ "text": "Sun Jan 28 04:14:40 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 29C P0 37W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 28C P0 36W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
166
+ }
167
+ ],
168
+ "execution_count": 6,
169
+ "metadata": {
170
+ "datalore": {
171
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
172
+ "type": "CODE",
173
+ "hide_input_from_viewers": true,
174
+ "hide_output_from_viewers": true
175
+ }
176
+ }
177
+ },
178
+ {
179
+ "attachments": {},
180
+ "cell_type": "markdown",
181
+ "source": [
182
+ "## Loading the data set"
183
+ ],
184
+ "metadata": {
185
+ "datalore": {
186
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
187
+ "type": "MD",
188
+ "hide_input_from_viewers": false,
189
+ "hide_output_from_viewers": false,
190
+ "report_properties": {
191
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
192
+ }
193
+ }
194
+ }
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "source": [
199
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
200
+ ],
201
+ "outputs": [],
202
+ "execution_count": 7,
203
+ "metadata": {
204
+ "collapsed": false,
205
+ "gather": {
206
+ "logged": 1706415283301
207
+ }
208
+ }
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "source": [
213
+ "We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
214
+ ],
215
+ "metadata": {
216
+ "nteract": {
217
+ "transient": {
218
+ "deleting": false
219
+ }
220
+ }
221
+ }
222
+ },
223
+ {
224
+ "cell_type": "code",
225
+ "source": [
226
+ "ds = DatasetDict()\n",
227
+ "\n",
228
+ "for i in [\"test\", \"train\", \"val\"]:\n",
229
+ " tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
230
+ " ds[i] = Dataset(tab)\n",
231
+ "\n",
232
+ "dataset = ds"
233
+ ],
234
+ "outputs": [],
235
+ "execution_count": 8,
236
+ "metadata": {
237
+ "jupyter": {
238
+ "source_hidden": false,
239
+ "outputs_hidden": false
240
+ },
241
+ "nteract": {
242
+ "transient": {
243
+ "deleting": false
244
+ }
245
+ },
246
+ "gather": {
247
+ "logged": 1706415283944
248
+ }
249
+ }
250
+ },
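+ {
+ "cell_type": "markdown",
+ "source": [
+ "The `[:4]` slice keeps only the first four positions of each multi-hot label vector, matching the four classes retained in `CLASS_NAMES`. With made-up values:"
+ ],
+ "metadata": {
+ "collapsed": false
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "# Illustrative only: a hypothetical seven-outcome multi-hot vector pruned to four\n",
+ "[1, 0, 0, 1, 0, 1, 0][:4]"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {
+ "collapsed": false
+ }
+ },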
251
+ {
252
+ "cell_type": "markdown",
253
+ "source": [
254
+ "### Tokenisation and encoding"
255
+ ],
256
+ "metadata": {
257
+ "collapsed": false
258
+ }
259
+ },
260
+ {
261
+ "cell_type": "code",
262
+ "source": [
263
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
264
+ ],
265
+ "outputs": [],
266
+ "execution_count": 9,
267
+ "metadata": {
268
+ "datalore": {
269
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
270
+ "type": "CODE",
271
+ "hide_input_from_viewers": true,
272
+ "hide_output_from_viewers": true
273
+ },
274
+ "gather": {
275
+ "logged": 1706415284206
276
+ }
277
+ }
278
+ },
279
+ {
280
+ "cell_type": "code",
281
+ "source": [
282
+ "def tokenize_and_encode(examples):\n",
283
+ " return tokenizer(examples[\"text\"], truncation=True)"
284
+ ],
285
+ "outputs": [],
286
+ "execution_count": 10,
287
+ "metadata": {
288
+ "datalore": {
289
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
290
+ "type": "CODE",
291
+ "hide_input_from_viewers": true,
292
+ "hide_output_from_viewers": true
293
+ },
294
+ "gather": {
295
+ "logged": 1706415284614
296
+ }
297
+ }
298
+ },
299
+ {
300
+ "cell_type": "code",
301
+ "source": [
302
+ "cols = dataset[\"train\"].column_names\n",
303
+ "cols.remove(\"labels\")\n",
304
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
305
+ ],
306
+ "outputs": [
307
+ {
308
+ "output_type": "stream",
309
+ "name": "stderr",
310
+ "text": "Map: 100%|██████████| 15786/15786 [00:01<00:00, 10213.76 examples/s]\nMap: 100%|██████████| 73667/73667 [00:07<00:00, 10215.55 examples/s]\nMap: 100%|██████████| 15785/15785 [00:01<00:00, 10172.52 examples/s]\n"
311
+ }
312
+ ],
313
+ "execution_count": 11,
314
+ "metadata": {
315
+ "datalore": {
316
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
317
+ "type": "CODE",
318
+ "hide_input_from_viewers": true,
319
+ "hide_output_from_viewers": true
320
+ },
321
+ "gather": {
322
+ "logged": 1706415294450
323
+ }
324
+ }
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "source": [
329
+ "### Training"
330
+ ],
331
+ "metadata": {
332
+ "collapsed": false
333
+ }
334
+ },
335
+ {
336
+ "cell_type": "code",
337
+ "source": [
338
+ "class MultiLabelTrainer(Trainer):\n",
339
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
340
+ " labels = inputs.pop(\"labels\")\n",
341
+ " outputs = model(**inputs)\n",
342
+ " logits = outputs.logits\n",
343
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
344
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
345
+ " labels.float().view(-1, self.model.config.num_labels))\n",
346
+ " return (loss, outputs) if return_outputs else loss"
347
+ ],
348
+ "outputs": [],
349
+ "execution_count": 12,
350
+ "metadata": {
351
+ "datalore": {
352
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
353
+ "type": "CODE",
354
+ "hide_input_from_viewers": true,
355
+ "hide_output_from_viewers": true
356
+ },
357
+ "gather": {
358
+ "logged": 1706415294807
359
+ }
360
+ }
361
+ },
362
+ {
363
+ "cell_type": "code",
364
+ "source": [
365
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
366
+ ],
367
+ "outputs": [
368
+ {
369
+ "output_type": "stream",
370
+ "name": "stderr",
371
+ "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
372
+ }
373
+ ],
374
+ "execution_count": 13,
375
+ "metadata": {
376
+ "datalore": {
377
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
378
+ "type": "CODE",
379
+ "hide_input_from_viewers": true,
380
+ "hide_output_from_viewers": true
381
+ },
382
+ "gather": {
383
+ "logged": 1706415296683
384
+ }
385
+ }
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "source": [
390
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
391
+ " y_pred = torch.from_numpy(y_pred)\n",
392
+ " y_true = torch.from_numpy(y_true)\n",
393
+ "\n",
394
+ " if sigmoid:\n",
395
+ " y_pred = y_pred.sigmoid()\n",
396
+ "\n",
397
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
398
+ ],
399
+ "outputs": [],
400
+ "execution_count": 14,
401
+ "metadata": {
402
+ "datalore": {
403
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
404
+ "type": "CODE",
405
+ "hide_input_from_viewers": true,
406
+ "hide_output_from_viewers": true
407
+ },
408
+ "gather": {
409
+ "logged": 1706415296937
410
+ }
411
+ }
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "source": [
416
+ "def compute_metrics(eval_pred):\n",
417
+ " predictions, labels = eval_pred\n",
418
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
419
+ ],
420
+ "outputs": [],
421
+ "execution_count": 15,
422
+ "metadata": {
423
+ "datalore": {
424
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
425
+ "type": "CODE",
426
+ "hide_input_from_viewers": true,
427
+ "hide_output_from_viewers": true
428
+ },
429
+ "gather": {
430
+ "logged": 1706415297280
431
+ }
432
+ }
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "source": [
437
+ "args = TrainingArguments(\n",
438
+ " output_dir=\"vaers\",\n",
439
+ " evaluation_strategy=\"epoch\",\n",
440
+ " learning_rate=2e-5,\n",
441
+ " per_device_train_batch_size=BATCH_SIZE,\n",
442
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
443
+ " num_train_epochs=EPOCHS,\n",
444
+ " weight_decay=.01,\n",
445
+ " report_to=[\"wandb\"]\n",
446
+ ")"
447
+ ],
448
+ "outputs": [],
449
+ "execution_count": 16,
450
+ "metadata": {
451
+ "datalore": {
452
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
453
+ "type": "CODE",
454
+ "hide_input_from_viewers": true,
455
+ "hide_output_from_viewers": true
456
+ },
457
+ "gather": {
458
+ "logged": 1706415297551
459
+ }
460
+ }
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "source": [
465
+ "multi_label_trainer = MultiLabelTrainer(\n",
466
+ " model, \n",
467
+ " args, \n",
468
+ " train_dataset=ds_enc[\"train\"], \n",
469
+ " eval_dataset=ds_enc[\"test\"], \n",
470
+ " compute_metrics=compute_metrics, \n",
471
+ " tokenizer=tokenizer\n",
472
+ ")"
473
+ ],
474
+ "outputs": [
475
+ {
476
+ "output_type": "stream",
477
+ "name": "stderr",
478
+ "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
479
+ }
480
+ ],
481
+ "execution_count": 17,
482
+ "metadata": {
483
+ "datalore": {
484
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
485
+ "type": "CODE",
486
+ "hide_input_from_viewers": true,
487
+ "hide_output_from_viewers": true
488
+ },
489
+ "gather": {
490
+ "logged": 1706415297795
491
+ }
492
+ }
493
+ },
494
+ {
495
+ "cell_type": "code",
496
+ "source": [
497
+ "multi_label_trainer.evaluate()"
498
+ ],
499
+ "outputs": [
500
+ {
501
+ "output_type": "display_data",
502
+ "data": {
503
+ "text/plain": "<IPython.core.display.HTML object>",
504
+ "text/html": "\n <div>\n \n <progress value='987' max='987' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [987/987 01:13]\n </div>\n "
505
+ },
506
+ "metadata": {}
507
+ },
508
+ {
509
+ "output_type": "stream",
510
+ "name": "stderr",
511
+ "text": "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
512
+ },
513
+ {
514
+ "output_type": "display_data",
515
+ "data": {
516
+ "text/plain": "<IPython.core.display.HTML object>",
517
+ "text/html": "Tracking run with wandb version 0.16.2"
518
+ },
519
+ "metadata": {}
520
+ },
521
+ {
522
+ "output_type": "display_data",
523
+ "data": {
524
+ "text/plain": "<IPython.core.display.HTML object>",
525
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_041615-nnw129w4</code>"
526
+ },
527
+ "metadata": {}
528
+ },
529
+ {
530
+ "output_type": "display_data",
531
+ "data": {
532
+ "text/plain": "<IPython.core.display.HTML object>",
533
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/nnw129w4' target=\"_blank\">grateful-shadow-2</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
534
+ },
535
+ "metadata": {}
536
+ },
537
+ {
538
+ "output_type": "display_data",
539
+ "data": {
540
+ "text/plain": "<IPython.core.display.HTML object>",
541
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
542
+ },
543
+ "metadata": {}
544
+ },
545
+ {
546
+ "output_type": "display_data",
547
+ "data": {
548
+ "text/plain": "<IPython.core.display.HTML object>",
549
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/nnw129w4' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/nnw129w4</a>"
550
+ },
551
+ "metadata": {}
552
+ },
553
+ {
554
+ "output_type": "execute_result",
555
+ "execution_count": 18,
556
+ "data": {
557
+ "text/plain": "{'eval_loss': 0.712559163570404,\n 'eval_accuracy_thresh': 0.36481693387031555,\n 'eval_runtime': 76.4156,\n 'eval_samples_per_second': 206.581,\n 'eval_steps_per_second': 12.916}"
558
+ },
559
+ "metadata": {}
560
+ }
561
+ ],
562
+ "execution_count": 18,
563
+ "metadata": {
564
+ "datalore": {
565
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
566
+ "type": "CODE",
567
+ "hide_input_from_viewers": true,
568
+ "hide_output_from_viewers": true
569
+ },
570
+ "gather": {
571
+ "logged": 1706415378024
572
+ }
573
+ }
574
+ },
575
+ {
576
+ "cell_type": "code",
577
+ "source": [
578
+ "multi_label_trainer.train()"
579
+ ],
580
+ "outputs": [
581
+ {
582
+ "output_type": "display_data",
583
+ "data": {
584
+ "text/plain": "<IPython.core.display.HTML object>",
585
+ "text/html": "\n <div>\n \n <progress value='3001' max='4605' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [3001/4605 12:05 < 06:28, 4.13 it/s, Epoch 0.65/1]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
586
+ },
587
+ "metadata": {}
588
+ },
589
+ {
590
+ "output_type": "stream",
591
+ "name": "stderr",
592
+ "text": "Checkpoint destination directory vaers/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.2s\nCheckpoint destination directory vaers/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 13.4s\nCheckpoint destination directory vaers/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 13.0s\nCheckpoint destination directory vaers/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 11.6s\nCheckpoint destination directory vaers/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.6s\nCheckpoint destination directory vaers/checkpoint-3000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... "
593
+ }
594
+ ],
595
+ "execution_count": 19,
596
+ "metadata": {
597
+ "datalore": {
598
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
599
+ "type": "CODE",
600
+ "hide_input_from_viewers": true,
601
+ "hide_output_from_viewers": true
602
+ },
603
+ "gather": {
604
+ "logged": 1706411445752
605
+ }
606
+ }
607
+ },
608
+ {
609
+ "cell_type": "markdown",
610
+ "source": [
611
+ "### Evaluation"
612
+ ],
613
+ "metadata": {
614
+ "collapsed": false
615
+ }
616
+ },
617
+ {
618
+ "cell_type": "markdown",
619
+ "source": [
620
+ "We instantiate a classifier `pipeline` and push it to CUDA."
621
+ ],
622
+ "metadata": {
623
+ "collapsed": false
624
+ }
625
+ },
626
+ {
627
+ "cell_type": "code",
628
+ "source": [
629
+ "classifier = pipeline(\"text-classification\", \n",
630
+ " model, \n",
631
+ " tokenizer=tokenizer, \n",
632
+ " device=\"cuda:0\")"
633
+ ],
634
+ "outputs": [],
635
+ "execution_count": null,
636
+ "metadata": {
637
+ "datalore": {
638
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
639
+ "type": "CODE",
640
+ "hide_input_from_viewers": true,
641
+ "hide_output_from_viewers": true
642
+ },
643
+ "gather": {
644
+ "logged": 1706411459928
645
+ }
646
+ }
647
+ },
648
+ {
649
+ "cell_type": "markdown",
650
+ "source": [
651
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
652
+ ],
653
+ "metadata": {
654
+ "collapsed": false
655
+ }
656
+ },
657
+ {
658
+ "cell_type": "code",
659
+ "source": [
660
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
661
+ " max_length=None, \n",
662
+ " padding='max_length', \n",
663
+ " return_token_type_ids=True, \n",
664
+ " truncation=True)"
665
+ ],
666
+ "outputs": [],
667
+ "execution_count": null,
668
+ "metadata": {
669
+ "datalore": {
670
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
671
+ "type": "CODE",
672
+ "hide_input_from_viewers": true,
673
+ "hide_output_from_viewers": true
674
+ },
675
+ "gather": {
676
+ "logged": 1706411523285
677
+ }
678
+ }
679
+ },
680
+ {
681
+ "cell_type": "markdown",
682
+ "source": [
683
+ "Once we've made the data loadable by putting it into a `DataLoader`, we "
684
+ ],
685
+ "metadata": {
686
+ "collapsed": false
687
+ }
688
+ },
689
+ {
690
+ "cell_type": "code",
691
+ "source": [
692
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
693
+ " torch.tensor(test_encodings['attention_mask']), \n",
694
+ " torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
695
+ " torch.tensor(test_encodings['token_type_ids']))\n",
696
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
697
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
698
+ " batch_size=BATCH_SIZE)"
699
+ ],
700
+ "outputs": [],
701
+ "execution_count": null,
702
+ "metadata": {
703
+ "datalore": {
704
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
705
+ "type": "CODE",
706
+ "hide_input_from_viewers": true,
707
+ "hide_output_from_viewers": true
708
+ },
709
+ "gather": {
710
+ "logged": 1706411543379
711
+ }
712
+ }
713
+ },
714
+ {
715
+ "cell_type": "code",
716
+ "source": [
717
+ "model.eval()\n",
718
+ "\n",
719
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
720
+ "\n",
721
+ "for i, batch in enumerate(test_dataloader):\n",
722
+ " batch = tuple(t.to(device) for t in batch)\n",
723
+ " \n",
724
+ " # Unpack the inputs from our dataloader\n",
725
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
726
+ " \n",
727
+ " with torch.no_grad():\n",
728
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
729
+ " b_logit_pred = outs[0]\n",
730
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
731
+ "\n",
732
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
733
+ " pred_label = pred_label.to('cpu').numpy()\n",
734
+ " b_labels = b_labels.to('cpu').numpy()\n",
735
+ "\n",
736
+ " tokenized_texts.append(b_input_ids)\n",
737
+ " logit_preds.append(b_logit_pred)\n",
738
+ " true_labels.append(b_labels)\n",
739
+ " pred_labels.append(pred_label)\n",
740
+ "\n",
741
+ "# Flatten outputs\n",
742
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
743
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
744
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
745
+ "\n",
746
+ "# Converting flattened binary values to boolean values\n",
747
+ "true_bools = [tl == 1 for tl in true_labels]\n",
748
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
749
+ ],
750
+ "outputs": [],
751
+ "execution_count": null,
752
+ "metadata": {
753
+ "datalore": {
754
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
755
+ "type": "CODE",
756
+ "hide_input_from_viewers": true,
757
+ "hide_output_from_viewers": true
758
+ },
759
+ "gather": {
760
+ "logged": 1706411587843
761
+ }
762
+ }
763
+ },
764
+ {
765
+ "cell_type": "markdown",
766
+ "source": [
767
+ "We create a classification report:"
768
+ ],
769
+ "metadata": {
770
+ "collapsed": false
771
+ }
772
+ },
773
+ {
774
+ "cell_type": "code",
775
+ "source": [
776
+ "print('Test micro-F1 score: ', f1_score(true_bools, pred_bools, average='micro'))\n",
777
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
778
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
779
+ "print(clf_report)"
780
+ ],
781
+ "outputs": [],
782
+ "execution_count": null,
783
+ "metadata": {
784
+ "datalore": {
785
+ "node_id": "eBprrgF086mznPbPVBpOLS",
786
+ "type": "CODE",
787
+ "hide_input_from_viewers": true,
788
+ "hide_output_from_viewers": true
789
+ },
790
+ "gather": {
791
+ "logged": 1706411588249
792
+ }
793
+ }
794
+ },
795
+ {
796
+ "cell_type": "markdown",
797
+ "source": [
798
+ "Finally, we render a 'head-to-head' comparison table that maps each report text to its actual and predicted labels."
799
+ ],
800
+ "metadata": {
801
+ "collapsed": false
802
+ }
803
+ },
804
+ {
805
+ "cell_type": "code",
806
+ "source": [
807
+ "# Creating a map of class names from class numbers\n",
808
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
809
+ ],
810
+ "outputs": [],
811
+ "execution_count": null,
812
+ "metadata": {
813
+ "datalore": {
814
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
815
+ "type": "CODE",
816
+ "hide_input_from_viewers": true,
817
+ "hide_output_from_viewers": true
818
+ },
819
+ "gather": {
820
+ "logged": 1706411588638
821
+ }
822
+ }
823
+ },
824
+ {
825
+ "cell_type": "code",
826
+ "source": [
827
+ "true_label_idxs, pred_label_idxs = [], []\n",
828
+ "\n",
829
+ "for vals in true_bools:\n",
830
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
831
+ "for vals in pred_bools:\n",
832
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
833
+ ],
834
+ "outputs": [],
835
+ "execution_count": null,
836
+ "metadata": {
837
+ "datalore": {
838
+ "node_id": "jH0S35dDteUch01sa6me6e",
839
+ "type": "CODE",
840
+ "hide_input_from_viewers": true,
841
+ "hide_output_from_viewers": true
842
+ },
843
+ "gather": {
844
+ "logged": 1706411589004
845
+ }
846
+ }
847
+ },
848
+ {
849
+ "cell_type": "code",
850
+ "source": [
851
+ "true_label_texts, pred_label_texts = [], []\n",
852
+ "\n",
853
+ "for vals in true_label_idxs:\n",
854
+ " if vals:\n",
855
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
856
+ " else:\n",
857
+ " true_label_texts.append(vals)\n",
858
+ "\n",
859
+ "for vals in pred_label_idxs:\n",
860
+ " if vals:\n",
861
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
862
+ " else:\n",
863
+ " pred_label_texts.append(vals)"
864
+ ],
865
+ "outputs": [],
866
+ "execution_count": null,
867
+ "metadata": {
868
+ "datalore": {
869
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
870
+ "type": "CODE",
871
+ "hide_input_from_viewers": true,
872
+ "hide_output_from_viewers": true
873
+ },
874
+ "gather": {
875
+ "logged": 1706411589301
876
+ }
877
+ }
878
+ },
879
+ {
880
+ "cell_type": "code",
881
+ "source": [
882
+ "symptom_texts = [tokenizer.decode(text,\n",
883
+ " skip_special_tokens=True,\n",
884
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
885
+ ],
886
+ "outputs": [],
887
+ "execution_count": null,
888
+ "metadata": {
889
+ "datalore": {
890
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
891
+ "type": "CODE",
892
+ "hide_input_from_viewers": true,
893
+ "hide_output_from_viewers": true
894
+ },
895
+ "gather": {
896
+ "logged": 1706411591952
897
+ }
898
+ }
899
+ },
900
+ {
901
+ "cell_type": "code",
902
+ "source": [
903
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
904
+ " 'true_labels': true_label_texts, \n",
905
+ " 'pred_labels':pred_label_texts})\n",
906
+ "comparisons_df.to_csv('comparisons.csv')\n",
907
+ "comparisons_df"
908
+ ],
909
+ "outputs": [],
910
+ "execution_count": null,
911
+ "metadata": {
912
+ "datalore": {
913
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
914
+ "type": "CODE",
915
+ "hide_input_from_viewers": true,
916
+ "hide_output_from_viewers": true
917
+ },
918
+ "gather": {
919
+ "logged": 1706411592512
920
+ }
921
+ }
922
+ },
923
+ {
924
+ "cell_type": "markdown",
925
+ "source": [
926
+ "### Shapley analysis"
927
+ ],
928
+ "metadata": {
929
+ "collapsed": false
930
+ }
931
+ },
932
+ {
933
+ "cell_type": "code",
934
+ "source": [
935
+ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
936
+ ],
937
+ "outputs": [],
938
+ "execution_count": null,
939
+ "metadata": {
940
+ "datalore": {
941
+ "node_id": "OpdZcoenX2HwzLdai7K5UA",
942
+ "type": "CODE",
943
+ "hide_input_from_viewers": true,
944
+ "hide_output_from_viewers": true
945
+ },
946
+ "gather": {
947
+ "logged": 1706415109071
948
+ }
949
+ }
950
+ },
951
+ {
952
+ "cell_type": "markdown",
953
+ "source": [
954
+ "#### Sampling correct predictions\n",
955
+ "\n",
956
+ "First, let's look at some correct predictions of deaths:"
957
+ ],
958
+ "metadata": {
959
+ "nteract": {
960
+ "transient": {
961
+ "deleting": false
962
+ }
963
+ }
964
+ }
965
+ },
966
+ {
967
+ "cell_type": "code",
968
+ "source": [
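+ "# Keep rows whose true label is exactly ['DIED']; note this filters on the true label only, not on prediction correctness\n",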
969
+ "correct_death_predictions = comparisons_df[comparisons_df['true_labels'].astype(str) == \"['DIED']\"]"
970
+ ],
971
+ "outputs": [],
972
+ "execution_count": null,
973
+ "metadata": {
974
+ "jupyter": {
975
+ "source_hidden": false,
976
+ "outputs_hidden": false
977
+ },
978
+ "nteract": {
979
+ "transient": {
980
+ "deleting": false
981
+ }
982
+ },
983
+ "gather": {
984
+ "logged": 1706414973990
985
+ }
986
+ }
987
+ },
988
+ {
989
+ "cell_type": "code",
990
+ "source": [
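+ "# Sample six of these reports and truncate each to its first 512 characters\n",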
991
+ "texts = [i[:512] for i in correct_death_predictions.sample(n=6).symptom_text]\n",
992
+ "idxs = [i for i in range(len(texts))]\n",
993
+ "\n",
994
+ "d_s = Dataset(Table.from_arrays([idxs, texts], names=[\"idx\", \"texts\"]))"
995
+ ],
996
+ "outputs": [],
997
+ "execution_count": null,
998
+ "metadata": {
999
+ "jupyter": {
1000
+ "source_hidden": false,
1001
+ "outputs_hidden": false
1002
+ },
1003
+ "nteract": {
1004
+ "transient": {
1005
+ "deleting": false
1006
+ }
1007
+ },
1008
+ "gather": {
1009
+ "logged": 1706415114683
1010
+ }
1011
+ }
1012
+ },
1013
+ {
1014
+ "cell_type": "code",
1015
+ "source": [
1016
+ "shap_values = explainer(d_s[\"texts\"])"
1017
+ ],
1018
+ "outputs": [],
1019
+ "execution_count": null,
1020
+ "metadata": {
1021
+ "jupyter": {
1022
+ "source_hidden": false,
1023
+ "outputs_hidden": false
1024
+ },
1025
+ "nteract": {
1026
+ "transient": {
1027
+ "deleting": false
1028
+ }
1029
+ },
1030
+ "gather": {
1031
+ "logged": 1706415129229
1032
+ }
1033
+ }
1034
+ },
1035
+ {
1036
+ "cell_type": "code",
1037
+ "source": [
1038
+ "shap.plots.text(shap_values)"
1039
+ ],
1040
+ "outputs": [],
1041
+ "execution_count": null,
1042
+ "metadata": {
1043
+ "jupyter": {
1044
+ "source_hidden": false,
1045
+ "outputs_hidden": false
1046
+ },
1047
+ "nteract": {
1048
+ "transient": {
1049
+ "deleting": false
1050
+ }
1051
+ },
1052
+ "gather": {
1053
+ "logged": 1706415151494
1054
+ }
1055
+ }
1056
+ },
1057
+ {
1058
+ "cell_type": "code",
1059
+ "source": [],
1060
+ "outputs": [],
1061
+ "execution_count": null,
1062
+ "metadata": {
1063
+ "jupyter": {
1064
+ "source_hidden": false,
1065
+ "outputs_hidden": false
1066
+ },
1067
+ "nteract": {
1068
+ "transient": {
1069
+ "deleting": false
1070
+ }
1071
+ }
1072
+ }
1073
+ }
1074
+ ],
1075
+ "metadata": {
1076
+ "kernelspec": {
1077
+ "name": "python3",
1078
+ "language": "python",
1079
+ "display_name": "Python 3 (ipykernel)"
1080
+ },
1081
+ "datalore": {
1082
+ "computation_mode": "JUPYTER",
1083
+ "package_manager": "pip",
1084
+ "base_environment": "default",
1085
+ "packages": [
1086
+ {
1087
+ "name": "datasets",
1088
+ "version": "2.16.1",
1089
+ "source": "PIP"
1090
+ },
1091
+ {
1092
+ "name": "torch",
1093
+ "version": "2.1.2",
1094
+ "source": "PIP"
1095
+ },
1096
+ {
1097
+ "name": "accelerate",
1098
+ "version": "0.26.1",
1099
+ "source": "PIP"
1100
+ }
1101
+ ],
1102
+ "report_row_ids": [
1103
+ "un8W7ez7ZwoGb5Co6nydEV",
1104
+ "40nN9Hvgi1clHNV5RAemI5",
1105
+ "TgRD90H5NSPpKS41OeXI1w",
1106
+ "ZOm5BfUs3h1EGLaUkBGeEB",
1107
+ "kOP0CZWNSk6vqE3wkPp7Vc",
1108
+ "W4PWcOu2O2pRaZyoE2W80h",
1109
+ "RolbOnQLIftk0vy9mIcz5M",
1110
+ "8OPhUgbaNJmOdiq5D3a6vK",
1111
+ "5Qrt3jSvSrpK6Ne1hS6shL",
1112
+ "hTq7nFUrovN5Ao4u6dIYWZ",
1113
+ "I8WNZLpJ1DVP2wiCW7YBIB",
1114
+ "SawhU3I9BewSE1XBPstpNJ",
1115
+ "80EtLEl2FIE4FqbWnUD3nT"
1116
+ ],
1117
+ "version": 3
1118
+ },
1119
+ "microsoft": {
1120
+ "ms_spell_check": {
1121
+ "ms_spell_check_language": "en"
1122
+ },
1123
+ "host": {
1124
+ "AzureML": {
1125
+ "notebookHasBeenCompleted": true
1126
+ }
1127
+ }
1128
+ },
1129
+ "language_info": {
1130
+ "name": "python",
1131
+ "version": "3.8.5",
1132
+ "mimetype": "text/x-python",
1133
+ "codemirror_mode": {
1134
+ "name": "ipython",
1135
+ "version": 3
1136
+ },
1137
+ "pygments_lexer": "ipython3",
1138
+ "nbconvert_exporter": "python",
1139
+ "file_extension": ".py"
1140
+ },
1141
+ "nteract": {
1142
+ "version": "nteract-front-end@1.0.0"
1143
+ }
1144
+ },
1145
+ "nbformat": 4,
1146
+ "nbformat_minor": 4
1147
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-15-7-36Z.ipynb ADDED
@@ -0,0 +1,1452 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Designed to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "nteract": {
17
+ "transient": {
18
+ "deleting": false
19
+ }
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "# %pip install accelerate -U"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 2,
31
+ "metadata": {
32
+ "nteract": {
33
+ "transient": {
34
+ "deleting": false
35
+ }
36
+ }
37
+ },
38
+ "outputs": [],
39
+ "source": [
40
+ "# %pip install transformers datasets shap watermark wandb"
41
+ ]
42
+ },
43
+ {
44
+ "cell_type": "code",
45
+ "execution_count": 3,
46
+ "metadata": {
47
+ "datalore": {
48
+ "hide_input_from_viewers": false,
49
+ "hide_output_from_viewers": false,
50
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
51
+ "report_properties": {
52
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
53
+ },
54
+ "type": "CODE"
55
+ },
56
+ "gather": {
57
+ "logged": 1706449625034
58
+ },
59
+ "tags": []
60
+ },
61
+ "outputs": [
62
+ {
63
+ "name": "stderr",
64
+ "output_type": "stream",
65
+ "text": [
66
+ "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
67
+ " from .autonotebook import tqdm as notebook_tqdm\n",
68
+ "2024-01-28 14:18:31.729214: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
69
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
70
+ "2024-01-28 14:18:32.746966: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
71
+ "2024-01-28 14:18:32.747096: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
72
+ "2024-01-28 14:18:32.747111: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
73
+ ]
74
+ }
75
+ ],
76
+ "source": [
77
+ "import pandas as pd\n",
78
+ "import numpy as np\n",
79
+ "import torch\n",
80
+ "import os\n",
81
+ "from typing import List\n",
82
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
83
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
84
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
85
+ "from pyarrow import Table\n",
86
+ "import shap\n",
87
+ "import wandb\n",
88
+ "\n",
89
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
90
+ "\n",
91
+ "%load_ext watermark"
92
+ ]
93
+ },
94
+ {
95
+ "cell_type": "code",
96
+ "execution_count": 4,
97
+ "metadata": {
98
+ "collapsed": false,
99
+ "gather": {
100
+ "logged": 1706449721319
101
+ },
102
+ "jupyter": {
103
+ "outputs_hidden": false
104
+ }
105
+ },
106
+ "outputs": [],
107
+ "source": [
108
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
109
+ "\n",
110
+ "SEED: int = 42\n",
111
+ "\n",
112
+ "BATCH_SIZE: int = 16\n",
113
+ "EPOCHS: int = 3\n",
114
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
115
+ "\n",
116
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
117
+ " \"ER_VISIT\",\n",
118
+ " \"HOSPITAL\",\n",
119
+ " \"OFC_VISIT\",\n",
120
+ " #\"X_STAY\", # pruned\n",
121
+ " #\"DISABLE\", # pruned\n",
122
+ " #\"D_PRESENTED\" # pruned\n",
123
+ " ]\n",
124
+ "\n",
125
+ "\n",
126
+ "\n",
127
+ "\n",
128
+ "# WandB configuration\n",
129
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
130
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
131
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "code",
136
+ "execution_count": 5,
137
+ "metadata": {
138
+ "collapsed": false,
139
+ "jupyter": {
140
+ "outputs_hidden": false
141
+ }
142
+ },
143
+ "outputs": [
144
+ {
145
+ "name": "stdout",
146
+ "output_type": "stream",
147
+ "text": [
148
+ "torch : 1.12.0\n",
149
+ "pandas : 2.0.2\n",
150
+ "numpy : 1.23.5\n",
151
+ "shap : 0.44.1\n",
152
+ "re : 2.2.1\n",
153
+ "wandb : 0.16.2\n",
154
+ "logging: 0.5.1.2\n",
155
+ "\n"
156
+ ]
157
+ }
158
+ ],
159
+ "source": [
160
+ "%watermark --iversion"
161
+ ]
162
+ },
163
+ {
164
+ "cell_type": "code",
165
+ "execution_count": 6,
166
+ "metadata": {
167
+ "datalore": {
168
+ "hide_input_from_viewers": true,
169
+ "hide_output_from_viewers": true,
170
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
171
+ "type": "CODE"
172
+ }
173
+ },
174
+ "outputs": [
175
+ {
176
+ "name": "stdout",
177
+ "output_type": "stream",
178
+ "text": [
179
+ "Sun Jan 28 14:18:35 2024 \n",
180
+ "+---------------------------------------------------------------------------------------+\n",
181
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
182
+ "|-----------------------------------------+----------------------+----------------------+\n",
183
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
184
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
185
+ "| | | MIG M. |\n",
186
+ "|=========================================+======================+======================|\n",
187
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
188
+ "| N/A 29C P0 27W / 250W | 4MiB / 16384MiB | 0% Default |\n",
189
+ "| | | N/A |\n",
190
+ "+-----------------------------------------+----------------------+----------------------+\n",
191
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
192
+ "| N/A 29C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\n",
193
+ "| | | N/A |\n",
194
+ "+-----------------------------------------+----------------------+----------------------+\n",
195
+ " \n",
196
+ "+---------------------------------------------------------------------------------------+\n",
197
+ "| Processes: |\n",
198
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
199
+ "| ID ID Usage |\n",
200
+ "|=======================================================================================|\n",
201
+ "| No running processes found |\n",
202
+ "+---------------------------------------------------------------------------------------+\n"
203
+ ]
204
+ }
205
+ ],
206
+ "source": [
207
+ "!nvidia-smi"
208
+ ]
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "metadata": {
213
+ "datalore": {
214
+ "hide_input_from_viewers": false,
215
+ "hide_output_from_viewers": false,
216
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
217
+ "report_properties": {
218
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
219
+ },
220
+ "type": "MD"
221
+ }
222
+ },
223
+ "source": [
224
+ "## Loading the data set"
225
+ ]
226
+ },
227
+ {
228
+ "cell_type": "code",
229
+ "execution_count": 7,
230
+ "metadata": {
231
+ "collapsed": false,
232
+ "gather": {
233
+ "logged": 1706449040507
234
+ },
235
+ "jupyter": {
236
+ "outputs_hidden": false
237
+ }
238
+ },
239
+ "outputs": [],
240
+ "source": [
241
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
242
+ ]
243
+ },
244
+ {
245
+ "cell_type": "code",
246
+ "execution_count": 8,
247
+ "metadata": {
248
+ "collapsed": false,
249
+ "gather": {
250
+ "logged": 1706449044205
251
+ },
252
+ "jupyter": {
253
+ "outputs_hidden": false,
254
+ "source_hidden": false
255
+ },
256
+ "nteract": {
257
+ "transient": {
258
+ "deleting": false
259
+ }
260
+ }
261
+ },
262
+ "outputs": [
263
+ {
264
+ "data": {
265
+ "text/plain": [
266
+ "DatasetDict({\n",
267
+ " train: Dataset({\n",
268
+ " features: ['id', 'text', 'labels'],\n",
269
+ " num_rows: 1270444\n",
270
+ " })\n",
271
+ " test: Dataset({\n",
272
+ " features: ['id', 'text', 'labels'],\n",
273
+ " num_rows: 272238\n",
274
+ " })\n",
275
+ " val: Dataset({\n",
276
+ " features: ['id', 'text', 'labels'],\n",
277
+ " num_rows: 272238\n",
278
+ " })\n",
279
+ "})"
280
+ ]
281
+ },
282
+ "execution_count": 8,
283
+ "metadata": {},
284
+ "output_type": "execute_result"
285
+ }
286
+ ],
287
+ "source": [
288
+ "dataset"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 9,
294
+ "metadata": {},
295
+ "outputs": [],
296
+ "source": [
297
+ "SUBSAMPLING: float = 0.1"
298
+ ]
299
+ },
300
+ {
301
+ "cell_type": "code",
302
+ "execution_count": 10,
303
+ "metadata": {
304
+ "collapsed": false,
305
+ "gather": {
306
+ "logged": 1706449378281
307
+ },
308
+ "jupyter": {
309
+ "outputs_hidden": false,
310
+ "source_hidden": false
311
+ },
312
+ "nteract": {
313
+ "transient": {
314
+ "deleting": false
315
+ }
316
+ }
317
+ },
318
+ "outputs": [],
319
+ "source": [
320
+ "def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
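+ "    # Shuffle each split, then keep the first fraction of rows as a random subsample\n",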
321
+ " res = DatasetDict()\n",
322
+ "\n",
323
+ " res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
324
+ " res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
325
+ " res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
326
+ " \n",
327
+ " return res"
328
+ ]
329
+ },
330
+ {
331
+ "cell_type": "code",
332
+ "execution_count": 11,
333
+ "metadata": {
334
+ "collapsed": false,
335
+ "gather": {
336
+ "logged": 1706449384162
337
+ },
338
+ "jupyter": {
339
+ "outputs_hidden": false,
340
+ "source_hidden": false
341
+ },
342
+ "nteract": {
343
+ "transient": {
344
+ "deleting": false
345
+ }
346
+ }
347
+ },
348
+ "outputs": [],
349
+ "source": [
350
+ "dataset = minisample(dataset, SUBSAMPLING)"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 12,
356
+ "metadata": {
357
+ "collapsed": false,
358
+ "gather": {
359
+ "logged": 1706449387981
360
+ },
361
+ "jupyter": {
362
+ "outputs_hidden": false,
363
+ "source_hidden": false
364
+ },
365
+ "nteract": {
366
+ "transient": {
367
+ "deleting": false
368
+ }
369
+ }
370
+ },
371
+ "outputs": [
372
+ {
373
+ "data": {
374
+ "text/plain": [
375
+ "DatasetDict({\n",
376
+ " train: Dataset({\n",
377
+ " features: ['id', 'text', 'labels'],\n",
378
+ " num_rows: 127044\n",
379
+ " })\n",
380
+ " test: Dataset({\n",
381
+ " features: ['id', 'text', 'labels'],\n",
382
+ " num_rows: 27224\n",
383
+ " })\n",
384
+ " val: Dataset({\n",
385
+ " features: ['id', 'text', 'labels'],\n",
386
+ " num_rows: 27224\n",
387
+ " })\n",
388
+ "})"
389
+ ]
390
+ },
391
+ "execution_count": 12,
392
+ "metadata": {},
393
+ "output_type": "execute_result"
394
+ }
395
+ ],
396
+ "source": [
397
+ "dataset"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "markdown",
402
+ "metadata": {
403
+ "nteract": {
404
+ "transient": {
405
+ "deleting": false
406
+ }
407
+ }
408
+ },
409
+ "source": [
410
+ "We prune the label set down to its first four classes: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
411
+ ]
412
+ },
413
+ {
414
+ "cell_type": "code",
415
+ "execution_count": 13,
416
+ "metadata": {
417
+ "collapsed": false,
418
+ "gather": {
419
+ "logged": 1706449443055
420
+ },
421
+ "jupyter": {
422
+ "outputs_hidden": false,
423
+ "source_hidden": false
424
+ },
425
+ "nteract": {
426
+ "transient": {
427
+ "deleting": false
428
+ }
429
+ }
430
+ },
431
+ "outputs": [],
432
+ "source": [
433
+ "ds = DatasetDict()\n",
434
+ "\n",
435
+ "for i in [\"test\", \"train\", \"val\"]:\n",
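+ "    # The comprehension's i shadows the split name; i[:4] keeps the first four entries of each label vector (DIED, ER_VISIT, HOSPITAL, OFC_VISIT)\n",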
436
+ " tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
437
+ " ds[i] = Dataset(tab)\n",
438
+ "\n",
439
+ "dataset = ds"
440
+ ]
441
+ },
442
+ {
443
+ "cell_type": "markdown",
444
+ "metadata": {},
445
+ "source": [
446
+ "### Tokenisation and encoding"
447
+ ]
448
+ },
449
+ {
450
+ "cell_type": "code",
451
+ "execution_count": 14,
452
+ "metadata": {
453
+ "datalore": {
454
+ "hide_input_from_viewers": true,
455
+ "hide_output_from_viewers": true,
456
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
457
+ "type": "CODE"
458
+ },
459
+ "gather": {
460
+ "logged": 1706449638377
461
+ }
462
+ },
463
+ "outputs": [],
464
+ "source": [
465
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
466
+ ]
467
+ },
468
+ {
469
+ "cell_type": "code",
470
+ "execution_count": 15,
471
+ "metadata": {
472
+ "datalore": {
473
+ "hide_input_from_viewers": true,
474
+ "hide_output_from_viewers": true,
475
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
476
+ "type": "CODE"
477
+ },
478
+ "gather": {
479
+ "logged": 1706449642580
480
+ }
481
+ },
482
+ "outputs": [],
483
+ "source": [
484
+ "def tokenize_and_encode(examples):\n",
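+ "    # Truncate each report to the model's maximum sequence length; padding is applied dynamically at batch time\n",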
485
+ " return tokenizer(examples[\"text\"], truncation=True)"
486
+ ]
487
+ },
488
+ {
489
+ "cell_type": "code",
490
+ "execution_count": 16,
491
+ "metadata": {
492
+ "datalore": {
493
+ "hide_input_from_viewers": true,
494
+ "hide_output_from_viewers": true,
495
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
496
+ "type": "CODE"
497
+ },
498
+ "gather": {
499
+ "logged": 1706449721161
500
+ }
501
+ },
502
+ "outputs": [
503
+ {
504
+ "name": "stderr",
505
+ "output_type": "stream",
506
+ "text": [
507
+ "Map: 100%|██████████| 27224/27224 [00:10<00:00, 2638.52 examples/s]\n",
508
+ "Map: 100%|██████████| 127044/127044 [00:48<00:00, 2633.40 examples/s]\n",
509
+ "Map: 100%|██████████| 27224/27224 [00:10<00:00, 2613.19 examples/s]\n"
510
+ ]
511
+ }
512
+ ],
513
+ "source": [
514
+ "cols = dataset[\"train\"].column_names\n",
515
+ "cols.remove(\"labels\")\n",
516
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
517
+ ]
518
+ },
519
+ {
520
+ "cell_type": "markdown",
521
+ "metadata": {},
522
+ "source": [
523
+ "### Training"
524
+ ]
525
+ },
526
+ {
527
+ "cell_type": "code",
528
+ "execution_count": 17,
529
+ "metadata": {
530
+ "datalore": {
531
+ "hide_input_from_viewers": true,
532
+ "hide_output_from_viewers": true,
533
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
534
+ "type": "CODE"
535
+ },
536
+ "gather": {
537
+ "logged": 1706449743072
538
+ }
539
+ },
540
+ "outputs": [],
541
+ "source": [
542
+ "class MultiLabelTrainer(Trainer):\n",
543
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
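+ "        # Multi-label objective: binary cross-entropy with logits over all classes, rather than softmax cross-entropy\n",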
544
+ " labels = inputs.pop(\"labels\")\n",
545
+ " outputs = model(**inputs)\n",
546
+ " logits = outputs.logits\n",
547
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
548
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
549
+ " labels.float().view(-1, self.model.config.num_labels))\n",
550
+ " return (loss, outputs) if return_outputs else loss"
551
+ ]
552
+ },
553
+ {
554
+ "cell_type": "code",
555
+ "execution_count": 18,
556
+ "metadata": {
557
+ "datalore": {
558
+ "hide_input_from_viewers": true,
559
+ "hide_output_from_viewers": true,
560
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
561
+ "type": "CODE"
562
+ },
563
+ "gather": {
564
+ "logged": 1706449761205
565
+ }
566
+ },
567
+ "outputs": [
568
+ {
569
+ "name": "stderr",
570
+ "output_type": "stream",
571
+ "text": [
572
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
573
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
574
+ ]
575
+ }
576
+ ],
577
+ "source": [
578
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
579
+ ]
580
+ },
581
+ {
582
+ "cell_type": "code",
583
+ "execution_count": 19,
584
+ "metadata": {
585
+ "datalore": {
586
+ "hide_input_from_viewers": true,
587
+ "hide_output_from_viewers": true,
588
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
589
+ "type": "CODE"
590
+ },
591
+ "gather": {
592
+ "logged": 1706449761541
593
+ }
594
+ },
595
+ "outputs": [],
596
+ "source": [
597
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
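+ "    # Share of (sample, class) pairs where the thresholded probability matches the true binary label\n",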
598
+ " y_pred = torch.from_numpy(y_pred)\n",
599
+ " y_true = torch.from_numpy(y_true)\n",
600
+ "\n",
601
+ " if sigmoid:\n",
602
+ " y_pred = y_pred.sigmoid()\n",
603
+ "\n",
604
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
605
+ ]
606
+ },
607
+ {
608
+ "cell_type": "code",
609
+ "execution_count": 20,
610
+ "metadata": {
611
+ "datalore": {
612
+ "hide_input_from_viewers": true,
613
+ "hide_output_from_viewers": true,
614
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
615
+ "type": "CODE"
616
+ },
617
+ "gather": {
618
+ "logged": 1706449761720
619
+ }
620
+ },
621
+ "outputs": [],
622
+ "source": [
623
+ "def compute_metrics(eval_pred):\n",
624
+ " predictions, labels = eval_pred\n",
625
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "code",
630
+ "execution_count": 21,
631
+ "metadata": {
632
+ "datalore": {
633
+ "hide_input_from_viewers": true,
634
+ "hide_output_from_viewers": true,
635
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
636
+ "type": "CODE"
637
+ },
638
+ "gather": {
639
+ "logged": 1706449761893
640
+ }
641
+ },
642
+ "outputs": [],
643
+ "source": [
644
+ "args = TrainingArguments(\n",
645
+ " output_dir=\"vaers\",\n",
646
+ " evaluation_strategy=\"epoch\",\n",
647
+ " learning_rate=2e-5,\n",
648
+ " per_device_train_batch_size=BATCH_SIZE,\n",
649
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
650
+ " num_train_epochs=EPOCHS,\n",
651
+ " weight_decay=.01,\n",
652
+ " logging_steps=1,\n",
653
+ " run_name=f\"daedra-training\",\n",
654
+ " report_to=[\"wandb\"]\n",
655
+ ")"
656
+ ]
657
+ },
658
+ {
659
+ "cell_type": "code",
660
+ "execution_count": 22,
661
+ "metadata": {
662
+ "datalore": {
663
+ "hide_input_from_viewers": true,
664
+ "hide_output_from_viewers": true,
665
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
666
+ "type": "CODE"
667
+ },
668
+ "gather": {
669
+ "logged": 1706449769103
670
+ }
671
+ },
672
+ "outputs": [],
673
+ "source": [
674
+ "multi_label_trainer = MultiLabelTrainer(\n",
675
+ " model, \n",
676
+ " args, \n",
677
+ " train_dataset=ds_enc[\"train\"], \n",
678
+ " eval_dataset=ds_enc[\"test\"], \n",
679
+ " compute_metrics=compute_metrics, \n",
680
+ " tokenizer=tokenizer\n",
681
+ ")"
682
+ ]
683
+ },
684
+ {
685
+ "cell_type": "code",
686
+ "execution_count": 23,
687
+ "metadata": {
688
+ "datalore": {
689
+ "hide_input_from_viewers": true,
690
+ "hide_output_from_viewers": true,
691
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
692
+ "type": "CODE"
693
+ },
694
+ "gather": {
695
+ "logged": 1706449880674
696
+ }
697
+ },
698
+ "outputs": [
699
+ {
700
+ "name": "stderr",
701
+ "output_type": "stream",
702
+ "text": [
703
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
704
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
705
+ ]
706
+ },
707
+ {
708
+ "data": {
709
+ "text/html": [
710
+ "Tracking run with wandb version 0.16.2"
711
+ ],
712
+ "text/plain": [
713
+ "<IPython.core.display.HTML object>"
714
+ ]
715
+ },
716
+ "metadata": {},
717
+ "output_type": "display_data"
718
+ },
719
+ {
720
+ "data": {
721
+ "text/html": [
722
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141956-9lniqjvz</code>"
723
+ ],
724
+ "text/plain": [
725
+ "<IPython.core.display.HTML object>"
726
+ ]
727
+ },
728
+ "metadata": {},
729
+ "output_type": "display_data"
730
+ },
731
+ {
732
+ "data": {
733
+ "text/html": [
734
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
735
+ ],
736
+ "text/plain": [
737
+ "<IPython.core.display.HTML object>"
738
+ ]
739
+ },
740
+ "metadata": {},
741
+ "output_type": "display_data"
742
+ },
743
+ {
744
+ "data": {
745
+ "text/html": [
746
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
747
+ ],
748
+ "text/plain": [
749
+ "<IPython.core.display.HTML object>"
750
+ ]
751
+ },
752
+ "metadata": {},
753
+ "output_type": "display_data"
754
+ },
755
+ {
756
+ "data": {
757
+ "text/html": [
758
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a>"
759
+ ],
760
+ "text/plain": [
761
+ "<IPython.core.display.HTML object>"
762
+ ]
763
+ },
764
+ "metadata": {},
765
+ "output_type": "display_data"
766
+ },
767
+ {
768
+ "data": {
769
+ "text/html": [
770
+ "Finishing last run (ID:9lniqjvz) before initializing another..."
771
+ ],
772
+ "text/plain": [
773
+ "<IPython.core.display.HTML object>"
774
+ ]
775
+ },
776
+ "metadata": {},
777
+ "output_type": "display_data"
778
+ },
779
+ {
780
+ "data": {
781
+ "text/html": [
782
+ " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
783
+ ],
784
+ "text/plain": [
785
+ "<IPython.core.display.HTML object>"
786
+ ]
787
+ },
788
+ "metadata": {},
789
+ "output_type": "display_data"
790
+ },
791
+ {
792
+ "data": {
793
+ "text/html": [
794
+ "Find logs at: <code>./wandb/run-20240128_141956-9lniqjvz/logs</code>"
795
+ ],
796
+ "text/plain": [
797
+ "<IPython.core.display.HTML object>"
798
+ ]
799
+ },
800
+ "metadata": {},
801
+ "output_type": "display_data"
802
+ },
803
+ {
804
+ "data": {
805
+ "text/html": [
806
+ "Successfully finished last run (ID:9lniqjvz). Initializing new run:<br/>"
807
+ ],
808
+ "text/plain": [
809
+ "<IPython.core.display.HTML object>"
810
+ ]
811
+ },
812
+ "metadata": {},
813
+ "output_type": "display_data"
814
+ },
815
+ {
816
+ "data": {
817
+ "text/html": [
818
+ "Tracking run with wandb version 0.16.2"
819
+ ],
820
+ "text/plain": [
821
+ "<IPython.core.display.HTML object>"
822
+ ]
823
+ },
824
+ "metadata": {},
825
+ "output_type": "display_data"
826
+ },
827
+ {
828
+ "data": {
829
+ "text/html": [
830
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141958-5idmkcie</code>"
831
+ ],
832
+ "text/plain": [
833
+ "<IPython.core.display.HTML object>"
834
+ ]
835
+ },
836
+ "metadata": {},
837
+ "output_type": "display_data"
838
+ },
839
+ {
840
+ "data": {
841
+ "text/html": [
842
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
843
+ ],
844
+ "text/plain": [
845
+ "<IPython.core.display.HTML object>"
846
+ ]
847
+ },
848
+ "metadata": {},
849
+ "output_type": "display_data"
850
+ },
851
+ {
852
+ "data": {
853
+ "text/html": [
854
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
855
+ ],
856
+ "text/plain": [
857
+ "<IPython.core.display.HTML object>"
858
+ ]
859
+ },
860
+ "metadata": {},
861
+ "output_type": "display_data"
862
+ },
863
+ {
864
+ "data": {
865
+ "text/html": [
866
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a>"
867
+ ],
868
+ "text/plain": [
869
+ "<IPython.core.display.HTML object>"
870
+ ]
871
+ },
872
+ "metadata": {},
873
+ "output_type": "display_data"
874
+ },
875
+ {
876
+ "data": {
877
+ "text/html": [
878
+ "\n",
879
+ " <div>\n",
880
+ " \n",
881
+ " <progress value='1003' max='851' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
882
+ " [851/851 26:26]\n",
883
+ " </div>\n",
884
+ " "
885
+ ],
886
+ "text/plain": [
887
+ "<IPython.core.display.HTML object>"
888
+ ]
889
+ },
890
+ "metadata": {},
891
+ "output_type": "display_data"
892
+ },
893
+ {
894
+ "data": {
895
+ "text/html": [
896
+ "<style>\n",
897
+ " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
898
+ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
899
+ " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
900
+ " </style>\n",
901
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>▁</td></tr><tr><td>eval/loss</td><td>▁</td></tr><tr><td>eval/runtime</td><td>▁</td></tr><tr><td>eval/samples_per_second</td><td>▁</td></tr><tr><td>eval/steps_per_second</td><td>▁</td></tr><tr><td>train/global_step</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>0.55198</td></tr><tr><td>eval/loss</td><td>0.68442</td></tr><tr><td>eval/runtime</td><td>105.0436</td></tr><tr><td>eval/samples_per_second</td><td>259.168</td></tr><tr><td>eval/steps_per_second</td><td>8.101</td></tr><tr><td>train/global_step</td><td>0</td></tr></table><br/></div></div>"
902
+ ],
903
+ "text/plain": [
904
+ "<IPython.core.display.HTML object>"
905
+ ]
906
+ },
907
+ "metadata": {},
908
+ "output_type": "display_data"
909
+ },
910
+ {
911
+ "data": {
912
+ "text/html": [
913
+ " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
914
+ ],
915
+ "text/plain": [
916
+ "<IPython.core.display.HTML object>"
917
+ ]
918
+ },
919
+ "metadata": {},
920
+ "output_type": "display_data"
921
+ },
922
+ {
923
+ "data": {
924
+ "text/html": [
925
+ "Find logs at: <code>./wandb/run-20240128_141958-5idmkcie/logs</code>"
926
+ ],
927
+ "text/plain": [
928
+ "<IPython.core.display.HTML object>"
929
+ ]
930
+ },
931
+ "metadata": {},
932
+ "output_type": "display_data"
933
+ }
934
+ ],
935
+ "source": [
936
+ "if SUBSAMPLING != 1.0:\n",
937
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
938
+ "else:\n",
939
+ "    wandb_tag: List[str] = [\"full_sample\"]\n",
940
+ " \n",
941
+ "wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
942
+ "\n",
943
+ "multi_label_trainer.evaluate()\n",
944
+ "wandb.finish()"
945
+ ]
946
+ },
947
+ {
948
+ "cell_type": "code",
949
+ "execution_count": null,
950
+ "metadata": {
951
+ "datalore": {
952
+ "hide_input_from_viewers": true,
953
+ "hide_output_from_viewers": true,
954
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
955
+ "type": "CODE"
956
+ },
957
+ "gather": {
958
+ "logged": 1706449934637
959
+ }
960
+ },
961
+ "outputs": [
962
+ {
963
+ "data": {
964
+ "text/html": [
965
+ "Tracking run with wandb version 0.16.2"
966
+ ],
967
+ "text/plain": [
968
+ "<IPython.core.display.HTML object>"
969
+ ]
970
+ },
971
+ "metadata": {},
972
+ "output_type": "display_data"
973
+ },
974
+ {
975
+ "data": {
976
+ "text/html": [
977
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_142151-2mcc0ibc</code>"
978
+ ],
979
+ "text/plain": [
980
+ "<IPython.core.display.HTML object>"
981
+ ]
982
+ },
983
+ "metadata": {},
984
+ "output_type": "display_data"
985
+ },
986
+ {
987
+ "data": {
988
+ "text/html": [
989
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
990
+ ],
991
+ "text/plain": [
992
+ "<IPython.core.display.HTML object>"
993
+ ]
994
+ },
995
+ "metadata": {},
996
+ "output_type": "display_data"
997
+ },
998
+ {
999
+ "data": {
1000
+ "text/html": [
1001
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
1002
+ ],
1003
+ "text/plain": [
1004
+ "<IPython.core.display.HTML object>"
1005
+ ]
1006
+ },
1007
+ "metadata": {},
1008
+ "output_type": "display_data"
1009
+ },
1010
+ {
1011
+ "data": {
1012
+ "text/html": [
1013
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc</a>"
1014
+ ],
1015
+ "text/plain": [
1016
+ "<IPython.core.display.HTML object>"
1017
+ ]
1018
+ },
1019
+ "metadata": {},
1020
+ "output_type": "display_data"
1021
+ },
1022
+ {
1023
+ "data": {
1024
+ "text/html": [
1025
+ "\n",
1026
+ " <div>\n",
1027
+ " \n",
1028
+ " <progress value='3972' max='11913' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
1029
+ " [ 3972/11913 24:20 < 48:40, 2.72 it/s, Epoch 1/3]\n",
1030
+ " </div>\n",
1031
+ " <table border=\"1\" class=\"dataframe\">\n",
1032
+ " <thead>\n",
1033
+ " <tr style=\"text-align: left;\">\n",
1034
+ " <th>Epoch</th>\n",
1035
+ " <th>Training Loss</th>\n",
1036
+ " <th>Validation Loss</th>\n",
1037
+ " </tr>\n",
1038
+ " </thead>\n",
1039
+ " <tbody>\n",
1040
+ " </tbody>\n",
1041
+ "</table><p>"
1042
+ ],
1043
+ "text/plain": [
1044
+ "<IPython.core.display.HTML object>"
1045
+ ]
1046
+ },
1047
+ "metadata": {},
1048
+ "output_type": "display_data"
1049
+ },
1050
+ {
1051
+ "name": "stderr",
1052
+ "output_type": "stream",
1053
+ "text": [
1054
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.6s\n",
1055
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 22.7s\n",
1056
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 14.0s\n",
1057
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 15.2s\n",
1058
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.0s\n",
1059
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 12.4s\n",
1060
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 13.4s\n"
1061
+ ]
1062
+ }
1063
+ ],
1064
+ "source": [
1065
+ "if SUBSAMPLING != 1.0:\n",
1066
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
1067
+ "else:\n",
1068
+ "    wandb_tag: List[str] = [\"full_sample\"]\n",
1069
+ " \n",
1070
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
1071
+ "\n",
1072
+ "multi_label_trainer.train()\n",
1073
+ "wandb.finish()"
1074
+ ]
1075
+ },
1076
+ {
1077
+ "cell_type": "markdown",
1078
+ "metadata": {},
1079
+ "source": [
1080
+ "### Evaluation"
1081
+ ]
1082
+ },
1083
+ {
1084
+ "cell_type": "markdown",
1085
+ "metadata": {},
1086
+ "source": [
1087
+ "We instantiate a classifier `pipeline` and push it to CUDA."
1088
+ ]
1089
+ },
1090
+ {
1091
+ "cell_type": "code",
1092
+ "execution_count": null,
1093
+ "metadata": {
1094
+ "datalore": {
1095
+ "hide_input_from_viewers": true,
1096
+ "hide_output_from_viewers": true,
1097
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
1098
+ "type": "CODE"
1099
+ },
1100
+ "gather": {
1101
+ "logged": 1706411459928
1102
+ }
1103
+ },
1104
+ "outputs": [],
1105
+ "source": [
1106
+ "classifier = pipeline(\"text-classification\", \n",
1107
+ " model, \n",
1108
+ " tokenizer=tokenizer, \n",
1109
+ " device=\"cuda:0\")"
1110
+ ]
1111
+ },
1112
+ {
1113
+ "cell_type": "markdown",
1114
+ "metadata": {},
1115
+ "source": [
1116
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
1117
+ ]
1118
+ },
1119
+ {
1120
+ "cell_type": "code",
1121
+ "execution_count": null,
1122
+ "metadata": {
1123
+ "datalore": {
1124
+ "hide_input_from_viewers": true,
1125
+ "hide_output_from_viewers": true,
1126
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
1127
+ "type": "CODE"
1128
+ },
1129
+ "gather": {
1130
+ "logged": 1706411523285
1131
+ }
1132
+ },
1133
+ "outputs": [],
1134
+ "source": [
1135
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
1136
+ " max_length=None, \n",
1137
+ " padding='max_length', \n",
1138
+ " return_token_type_ids=True, \n",
1139
+ " truncation=True)"
1140
+ ]
1141
+ },
1142
+ {
1143
+ "cell_type": "markdown",
1144
+ "metadata": {},
1145
+ "source": [
1146
+ "Once we've made the data loadable by putting it into a `DataLoader`, we iterate over it batch by batch, collecting the logits, sigmoid probabilities and true labels for evaluation."
1147
+ ]
1148
+ },
1149
+ {
1150
+ "cell_type": "code",
1151
+ "execution_count": null,
1152
+ "metadata": {
1153
+ "datalore": {
1154
+ "hide_input_from_viewers": true,
1155
+ "hide_output_from_viewers": true,
1156
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
1157
+ "type": "CODE"
1158
+ },
1159
+ "gather": {
1160
+ "logged": 1706411543379
1161
+ }
1162
+ },
1163
+ "outputs": [],
1164
+ "source": [
1165
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
1166
+ " torch.tensor(test_encodings['attention_mask']), \n",
1167
+ " torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
1168
+ " torch.tensor(test_encodings['token_type_ids']))\n",
1169
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
1170
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
1171
+ " batch_size=BATCH_SIZE)"
1172
+ ]
1173
+ },
1174
+ {
1175
+ "cell_type": "code",
1176
+ "execution_count": null,
1177
+ "metadata": {
1178
+ "datalore": {
1179
+ "hide_input_from_viewers": true,
1180
+ "hide_output_from_viewers": true,
1181
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
1182
+ "type": "CODE"
1183
+ },
1184
+ "gather": {
1185
+ "logged": 1706411587843
1186
+ }
1187
+ },
1188
+ "outputs": [],
1189
+ "source": [
1190
+ "model.eval()\n",
1191
+ "\n",
1192
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
1193
+ "\n",
1194
+ "for i, batch in enumerate(test_dataloader):\n",
1195
+ " batch = tuple(t.to(device) for t in batch)\n",
1196
+ " \n",
1197
+ " # Unpack the inputs from our dataloader\n",
1198
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
1199
+ " \n",
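+ "    # Disable autograd for inference; sigmoid converts the per-class logits to probabilities\n",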
1200
+ " with torch.no_grad():\n",
1201
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
1202
+ " b_logit_pred = outs[0]\n",
1203
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
1204
+ "\n",
1205
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
1206
+ " pred_label = pred_label.to('cpu').numpy()\n",
1207
+ " b_labels = b_labels.to('cpu').numpy()\n",
1208
+ "\n",
1209
+ " tokenized_texts.append(b_input_ids)\n",
1210
+ " logit_preds.append(b_logit_pred)\n",
1211
+ " true_labels.append(b_labels)\n",
1212
+ " pred_labels.append(pred_label)\n",
1213
+ "\n",
1214
+ "# Flatten outputs\n",
1215
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
1216
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
1217
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
1218
+ "\n",
1219
+ "# Converting flattened binary values to boolean values\n",
1220
+ "true_bools = [tl == 1 for tl in true_labels]\n",
1221
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
1222
+ ]
1223
+ },
1224
+ {
1225
+ "cell_type": "markdown",
1226
+ "metadata": {},
1227
+ "source": [
1228
+ "We create a classification report:"
1229
+ ]
1230
+ },
1231
+ {
1232
+ "cell_type": "code",
1233
+ "execution_count": null,
1234
+ "metadata": {
1235
+ "datalore": {
1236
+ "hide_input_from_viewers": true,
1237
+ "hide_output_from_viewers": true,
1238
+ "node_id": "eBprrgF086mznPbPVBpOLS",
1239
+ "type": "CODE"
1240
+ },
1241
+ "gather": {
1242
+ "logged": 1706411588249
1243
+ }
1244
+ },
1245
+ "outputs": [],
1246
+ "source": [
1247
+ "print('Test micro-F1 score: ', f1_score(true_bools, pred_bools, average='micro'))\n",
1248
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
1249
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
1250
+ "print(clf_report)"
1251
+ ]
1252
+ },
1253
+ {
1254
+ "cell_type": "markdown",
1255
+ "metadata": {},
1256
+ "source": [
1257
+ "Finally, we render a 'head-to-head' comparison table that maps each report text to its actual and predicted labels."
1258
+ ]
1259
+ },
1260
+ {
1261
+ "cell_type": "code",
1262
+ "execution_count": null,
1263
+ "metadata": {
1264
+ "datalore": {
1265
+ "hide_input_from_viewers": true,
1266
+ "hide_output_from_viewers": true,
1267
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
1268
+ "type": "CODE"
1269
+ },
1270
+ "gather": {
1271
+ "logged": 1706411588638
1272
+ }
1273
+ },
1274
+ "outputs": [],
1275
+ "source": [
1276
+ "# Creating a map of class names from class numbers\n",
1277
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
1278
+ ]
1279
+ },
1280
+ {
1281
+ "cell_type": "code",
1282
+ "execution_count": null,
1283
+ "metadata": {
1284
+ "datalore": {
1285
+ "hide_input_from_viewers": true,
1286
+ "hide_output_from_viewers": true,
1287
+ "node_id": "jH0S35dDteUch01sa6me6e",
1288
+ "type": "CODE"
1289
+ },
1290
+ "gather": {
1291
+ "logged": 1706411589004
1292
+ }
1293
+ },
1294
+ "outputs": [],
1295
+ "source": [
1296
+ "true_label_idxs, pred_label_idxs = [], []\n",
1297
+ "\n",
1298
+ "for vals in true_bools:\n",
1299
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
1300
+ "for vals in pred_bools:\n",
1301
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
1302
+ ]
1303
+ },
1304
+ {
1305
+ "cell_type": "code",
1306
+ "execution_count": null,
1307
+ "metadata": {
1308
+ "datalore": {
1309
+ "hide_input_from_viewers": true,
1310
+ "hide_output_from_viewers": true,
1311
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
1312
+ "type": "CODE"
1313
+ },
1314
+ "gather": {
1315
+ "logged": 1706411589301
1316
+ }
1317
+ },
1318
+ "outputs": [],
1319
+ "source": [
1320
+ "true_label_texts, pred_label_texts = [], []\n",
1321
+ "\n",
1322
+ "for vals in true_label_idxs:\n",
1323
+ " if vals:\n",
1324
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
1325
+ " else:\n",
1326
+ " true_label_texts.append(vals)\n",
1327
+ "\n",
1328
+ "for vals in pred_label_idxs:\n",
1329
+ " if vals:\n",
1330
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
1331
+ " else:\n",
1332
+ " pred_label_texts.append(vals)"
1333
+ ]
1334
+ },
1335
+ {
1336
+ "cell_type": "code",
1337
+ "execution_count": null,
1338
+ "metadata": {
1339
+ "datalore": {
1340
+ "hide_input_from_viewers": true,
1341
+ "hide_output_from_viewers": true,
1342
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
1343
+ "type": "CODE"
1344
+ },
1345
+ "gather": {
1346
+ "logged": 1706411591952
1347
+ }
1348
+ },
1349
+ "outputs": [],
1350
+ "source": [
1351
+ "symptom_texts = [tokenizer.decode(text,\n",
1352
+ " skip_special_tokens=True,\n",
1353
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
1354
+ ]
1355
+ },
1356
+ {
1357
+ "cell_type": "code",
1358
+ "execution_count": null,
1359
+ "metadata": {
1360
+ "datalore": {
1361
+ "hide_input_from_viewers": true,
1362
+ "hide_output_from_viewers": true,
1363
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
1364
+ "type": "CODE"
1365
+ },
1366
+ "gather": {
1367
+ "logged": 1706411592512
1368
+ }
1369
+ },
1370
+ "outputs": [],
1371
+ "source": [
1372
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
1373
+ " 'true_labels': true_label_texts, \n",
1374
+ " 'pred_labels':pred_label_texts})\n",
1375
+ "comparisons_df.to_csv('comparisons.csv')\n",
1376
+ "comparisons_df"
1377
+ ]
1378
+ }
1379
+ ],
1380
+ "metadata": {
1381
+ "datalore": {
1382
+ "base_environment": "default",
1383
+ "computation_mode": "JUPYTER",
1384
+ "package_manager": "pip",
1385
+ "packages": [
1386
+ {
1387
+ "name": "datasets",
1388
+ "source": "PIP",
1389
+ "version": "2.16.1"
1390
+ },
1391
+ {
1392
+ "name": "torch",
1393
+ "source": "PIP",
1394
+ "version": "2.1.2"
1395
+ },
1396
+ {
1397
+ "name": "accelerate",
1398
+ "source": "PIP",
1399
+ "version": "0.26.1"
1400
+ }
1401
+ ],
1402
+ "report_row_ids": [
1403
+ "un8W7ez7ZwoGb5Co6nydEV",
1404
+ "40nN9Hvgi1clHNV5RAemI5",
1405
+ "TgRD90H5NSPpKS41OeXI1w",
1406
+ "ZOm5BfUs3h1EGLaUkBGeEB",
1407
+ "kOP0CZWNSk6vqE3wkPp7Vc",
1408
+ "W4PWcOu2O2pRaZyoE2W80h",
1409
+ "RolbOnQLIftk0vy9mIcz5M",
1410
+ "8OPhUgbaNJmOdiq5D3a6vK",
1411
+ "5Qrt3jSvSrpK6Ne1hS6shL",
1412
+ "hTq7nFUrovN5Ao4u6dIYWZ",
1413
+ "I8WNZLpJ1DVP2wiCW7YBIB",
1414
+ "SawhU3I9BewSE1XBPstpNJ",
1415
+ "80EtLEl2FIE4FqbWnUD3nT"
1416
+ ],
1417
+ "version": 3
1418
+ },
1419
+ "kernelspec": {
1420
+ "display_name": "Python 3.8 - Pytorch and Tensorflow",
1421
+ "language": "python",
1422
+ "name": "python38-azureml-pt-tf"
1423
+ },
1424
+ "language_info": {
1425
+ "codemirror_mode": {
1426
+ "name": "ipython",
1427
+ "version": 3
1428
+ },
1429
+ "file_extension": ".py",
1430
+ "mimetype": "text/x-python",
1431
+ "name": "python",
1432
+ "nbconvert_exporter": "python",
1433
+ "pygments_lexer": "ipython3",
1434
+ "version": "3.8.5"
1435
+ },
1436
+ "microsoft": {
1437
+ "host": {
1438
+ "AzureML": {
1439
+ "notebookHasBeenCompleted": true
1440
+ }
1441
+ },
1442
+ "ms_spell_check": {
1443
+ "ms_spell_check_language": "en"
1444
+ }
1445
+ },
1446
+ "nteract": {
1447
+ "version": "nteract-front-end@1.0.0"
1448
+ }
1449
+ },
1450
+ "nbformat": 4,
1451
+ "nbformat_minor": 4
1452
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-16-26-9Z.ipynb ADDED
@@ -0,0 +1,1246 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {}
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "source": [
15
+ "# %pip install accelerate -U"
16
+ ],
17
+ "outputs": [],
18
+ "execution_count": 1,
19
+ "metadata": {
20
+ "nteract": {
21
+ "transient": {
22
+ "deleting": false
23
+ }
24
+ },
25
+ "tags": []
26
+ }
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "source": [
31
+ "%pip install transformers datasets shap watermark wandb scikit-multilearn"
32
+ ],
33
+ "outputs": [
34
+ {
35
+ "output_type": "stream",
36
+ "name": "stdout",
37
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nCollecting scikit-multilearn\n Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.4/89.4 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: numba in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: stack-data in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: scikit-multilearn\nSuccessfully installed scikit-multilearn-0.2.0\nNote: you may need to restart the kernel to use updated packages.\n"
38
+ }
39
+ ],
40
+ "execution_count": 1,
41
+ "metadata": {
42
+ "nteract": {
43
+ "transient": {
44
+ "deleting": false
45
+ }
46
+ }
47
+ }
48
+ },
49
+ {
50
+ "cell_type": "code",
51
+ "source": [
52
+ "import pandas as pd\n",
53
+ "import numpy as np\n",
54
+ "import torch\n",
55
+ "import os\n",
56
+ "from typing import List\n",
57
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
58
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
59
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
60
+ "from pyarrow import Table\n",
61
+ "import shap\n",
62
+ "import wandb\n",
63
+ "from skmultilearn.problem_transform import LabelPowerset\n",
64
+ "\n",
65
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
66
+ "\n",
67
+ "%load_ext watermark"
68
+ ],
69
+ "outputs": [
70
+ {
71
+ "output_type": "stream",
72
+ "name": "stderr",
73
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 15:09:42.856486: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 15:09:43.818179: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 15:09:43.818307: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 15:09:43.818321: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
74
+ }
75
+ ],
76
+ "execution_count": 2,
77
+ "metadata": {
78
+ "datalore": {
79
+ "hide_input_from_viewers": false,
80
+ "hide_output_from_viewers": false,
81
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
82
+ "report_properties": {
83
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
84
+ },
85
+ "type": "CODE"
86
+ },
87
+ "gather": {
88
+ "logged": 1706454586481
89
+ },
90
+ "tags": []
91
+ }
92
+ },
93
+ {
94
+ "cell_type": "code",
95
+ "source": [
96
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
97
+ "\n",
98
+ "SEED: int = 42\n",
99
+ "\n",
100
+ "BATCH_SIZE: int = 16\n",
101
+ "EPOCHS: int = 3\n",
102
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
103
+ "\n",
104
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
105
+ " \"ER_VISIT\",\n",
106
+ " \"HOSPITAL\",\n",
107
+ " \"OFC_VISIT\",\n",
108
+ " #\"X_STAY\", # pruned\n",
109
+ " #\"DISABLE\", # pruned\n",
110
+ " #\"D_PRESENTED\" # pruned\n",
111
+ " ]\n",
112
+ "\n",
113
+ "\n",
114
+ "\n",
115
+ "\n",
116
+ "# WandB configuration\n",
117
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
118
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
119
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
120
+ ],
121
+ "outputs": [],
122
+ "execution_count": 3,
123
+ "metadata": {
124
+ "collapsed": false,
125
+ "gather": {
126
+ "logged": 1706454586654
127
+ },
128
+ "jupyter": {
129
+ "outputs_hidden": false
130
+ }
131
+ }
132
+ },
133
+ {
134
+ "cell_type": "code",
135
+ "source": [
136
+ "%watermark --iversion"
137
+ ],
138
+ "outputs": [
139
+ {
140
+ "output_type": "stream",
141
+ "name": "stdout",
142
+ "text": "shap : 0.44.1\nlogging: 0.5.1.2\npandas : 2.0.2\nnumpy : 1.23.5\ntorch : 1.12.0\nwandb : 0.16.2\nre : 2.2.1\n\n"
143
+ }
144
+ ],
145
+ "execution_count": 4,
146
+ "metadata": {
147
+ "collapsed": false,
148
+ "jupyter": {
149
+ "outputs_hidden": false
150
+ }
151
+ }
152
+ },
153
+ {
154
+ "cell_type": "code",
155
+ "source": [
156
+ "!nvidia-smi"
157
+ ],
158
+ "outputs": [
159
+ {
160
+ "output_type": "stream",
161
+ "name": "stdout",
162
+ "text": "Sun Jan 28 15:09:47 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 30C P0 38W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 29C P0 38W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
163
+ }
164
+ ],
165
+ "execution_count": 5,
166
+ "metadata": {
167
+ "datalore": {
168
+ "hide_input_from_viewers": true,
169
+ "hide_output_from_viewers": true,
170
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
171
+ "type": "CODE"
172
+ }
173
+ }
174
+ },
175
+ {
176
+ "cell_type": "markdown",
177
+ "source": [
178
+ "## Loading the data set"
179
+ ],
180
+ "metadata": {
181
+ "datalore": {
182
+ "hide_input_from_viewers": false,
183
+ "hide_output_from_viewers": false,
184
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
185
+ "report_properties": {
186
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
187
+ },
188
+ "type": "MD"
189
+ }
190
+ }
191
+ },
192
+ {
193
+ "cell_type": "code",
194
+ "source": [
195
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
196
+ ],
197
+ "outputs": [],
198
+ "execution_count": 7,
199
+ "metadata": {
200
+ "collapsed": false,
201
+ "gather": {
202
+ "logged": 1706449040507
203
+ },
204
+ "jupyter": {
205
+ "outputs_hidden": false
206
+ }
207
+ }
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "source": [
212
+ "dataset"
213
+ ],
214
+ "outputs": [
215
+ {
216
+ "output_type": "execute_result",
217
+ "execution_count": 8,
218
+ "data": {
219
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 272238\n })\n})"
220
+ },
221
+ "metadata": {}
222
+ }
223
+ ],
224
+ "execution_count": 8,
225
+ "metadata": {
226
+ "collapsed": false,
227
+ "gather": {
228
+ "logged": 1706449044205
229
+ },
230
+ "jupyter": {
231
+ "outputs_hidden": false,
232
+ "source_hidden": false
233
+ },
234
+ "nteract": {
235
+ "transient": {
236
+ "deleting": false
237
+ }
238
+ }
239
+ }
240
+ },
241
+ {
242
+ "cell_type": "code",
243
+ "source": [
244
+ "SUBSAMPLING: float = 0.1"
245
+ ],
246
+ "outputs": [],
247
+ "execution_count": 9,
248
+ "metadata": {}
249
+ },
250
+ {
251
+ "cell_type": "code",
252
+ "source": [
253
+ "def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
254
+ " res = DatasetDict()\n",
255
+ "\n",
256
+ " res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
257
+ " res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
258
+ " res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
259
+ " \n",
260
+ " return res"
261
+ ],
262
+ "outputs": [],
263
+ "execution_count": 10,
264
+ "metadata": {
265
+ "collapsed": false,
266
+ "gather": {
267
+ "logged": 1706449378281
268
+ },
269
+ "jupyter": {
270
+ "outputs_hidden": false,
271
+ "source_hidden": false
272
+ },
273
+ "nteract": {
274
+ "transient": {
275
+ "deleting": false
276
+ }
277
+ }
278
+ }
279
+ },
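Note that `shuffle()` above runs unseeded, so each invocation draws a different subsample even though a `SEED` constant is defined earlier. A minimal sketch of a reproducible variant, assuming the `datasets` objects defined above (the `minisample_seeded` name is hypothetical):

```python
# Seeded variant of minisample(): same logic, but reproducible draws.
def minisample_seeded(ds: DatasetDict, fraction: float, seed: int = SEED) -> DatasetDict:
    res = DatasetDict()
    for split in ("train", "test", "val"):
        n = round(len(ds[split]) * fraction)
        res[split] = Dataset.from_dict(ds[split].shuffle(seed=seed)[:n])
    return res
```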
280
+ {
281
+ "cell_type": "code",
282
+ "source": [
283
+ "dataset = minisample(dataset, SUBSAMPLING)"
284
+ ],
285
+ "outputs": [],
286
+ "execution_count": 11,
287
+ "metadata": {
288
+ "collapsed": false,
289
+ "gather": {
290
+ "logged": 1706449384162
291
+ },
292
+ "jupyter": {
293
+ "outputs_hidden": false,
294
+ "source_hidden": false
295
+ },
296
+ "nteract": {
297
+ "transient": {
298
+ "deleting": false
299
+ }
300
+ }
301
+ }
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "source": [
306
+ "dataset"
307
+ ],
308
+ "outputs": [
309
+ {
310
+ "output_type": "execute_result",
311
+ "execution_count": 12,
312
+ "data": {
313
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 127044\n })\n test: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 27224\n })\n val: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 27224\n })\n})"
314
+ },
315
+ "metadata": {}
316
+ }
317
+ ],
318
+ "execution_count": 12,
319
+ "metadata": {
320
+ "collapsed": false,
321
+ "gather": {
322
+ "logged": 1706449387981
323
+ },
324
+ "jupyter": {
325
+ "outputs_hidden": false,
326
+ "source_hidden": false
327
+ },
328
+ "nteract": {
329
+ "transient": {
330
+ "deleting": false
331
+ }
332
+ }
333
+ }
334
+ },
335
+ {
336
+ "cell_type": "markdown",
337
+ "source": [
338
+ "We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
339
+ ],
340
+ "metadata": {
341
+ "nteract": {
342
+ "transient": {
343
+ "deleting": false
344
+ }
345
+ }
346
+ }
347
+ },
348
+ {
349
+ "cell_type": "code",
350
+ "source": [
351
+ "ds = DatasetDict()\n",
352
+ "\n",
353
+ "for i in [\"test\", \"train\", \"val\"]:\n",
354
+ " tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
355
+ " ds[i] = Dataset(tab)\n",
356
+ "\n",
357
+ "dataset = ds"
358
+ ],
359
+ "outputs": [],
360
+ "execution_count": 13,
361
+ "metadata": {
362
+ "collapsed": false,
363
+ "gather": {
364
+ "logged": 1706449443055
365
+ },
366
+ "jupyter": {
367
+ "outputs_hidden": false,
368
+ "source_hidden": false
369
+ },
370
+ "nteract": {
371
+ "transient": {
372
+ "deleting": false
373
+ }
374
+ }
375
+ }
376
+ },
377
+ {
378
+ "cell_type": "markdown",
379
+ "source": [
380
+ "### Tokenisation and encoding"
381
+ ],
382
+ "metadata": {}
383
+ },
384
+ {
385
+ "cell_type": "code",
386
+ "source": [
387
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
388
+ ],
389
+ "outputs": [],
390
+ "execution_count": 14,
391
+ "metadata": {
392
+ "datalore": {
393
+ "hide_input_from_viewers": true,
394
+ "hide_output_from_viewers": true,
395
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
396
+ "type": "CODE"
397
+ },
398
+ "gather": {
399
+ "logged": 1706449638377
400
+ }
401
+ }
402
+ },
403
+ {
404
+ "cell_type": "code",
405
+ "source": [
406
+ "def tokenize_and_encode(examples):\n",
407
+ " return tokenizer(examples[\"text\"], truncation=True)"
408
+ ],
409
+ "outputs": [],
410
+ "execution_count": 15,
411
+ "metadata": {
412
+ "datalore": {
413
+ "hide_input_from_viewers": true,
414
+ "hide_output_from_viewers": true,
415
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
416
+ "type": "CODE"
417
+ },
418
+ "gather": {
419
+ "logged": 1706449642580
420
+ }
421
+ }
422
+ },
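As a quick sanity check, `tokenize_and_encode` can be exercised on a toy report (the text below is illustrative, not drawn from the dataset) before mapping it over the whole corpus:

```python
# Batched input mirrors what dataset.map(..., batched=True) passes in.
sample = tokenize_and_encode({"text": ["Patient reported mild fever after vaccination."]})
print(list(sample.keys()))         # ['input_ids', 'attention_mask'] for DistilBERT
print(sample["input_ids"][0][:8])  # first few token ids, starting with [CLS]
```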
423
+ {
424
+ "cell_type": "code",
425
+ "source": [
426
+ "cols = dataset[\"train\"].column_names\n",
427
+ "cols.remove(\"labels\")\n",
428
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
429
+ ],
430
+ "outputs": [
431
+ {
432
+ "output_type": "stream",
433
+ "name": "stderr",
434
+ "text": "Map: 100%|██████████| 27224/27224 [00:10<00:00, 2638.52 examples/s]\nMap: 100%|██████████| 127044/127044 [00:48<00:00, 2633.40 examples/s]\nMap: 100%|██████████| 27224/27224 [00:10<00:00, 2613.19 examples/s]\n"
435
+ }
436
+ ],
437
+ "execution_count": 16,
438
+ "metadata": {
439
+ "datalore": {
440
+ "hide_input_from_viewers": true,
441
+ "hide_output_from_viewers": true,
442
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
443
+ "type": "CODE"
444
+ },
445
+ "gather": {
446
+ "logged": 1706449721161
447
+ }
448
+ }
449
+ },
450
+ {
451
+ "cell_type": "markdown",
452
+ "source": [
453
+ "### Training"
454
+ ],
455
+ "metadata": {}
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "source": [
460
+ "class MultiLabelTrainer(Trainer):\n",
461
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
462
+ " labels = inputs.pop(\"labels\")\n",
463
+ " outputs = model(**inputs)\n",
464
+ " logits = outputs.logits\n",
465
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
466
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
467
+ " labels.float().view(-1, self.model.config.num_labels))\n",
468
+ " return (loss, outputs) if return_outputs else loss"
469
+ ],
470
+ "outputs": [],
471
+ "execution_count": 17,
472
+ "metadata": {
473
+ "datalore": {
474
+ "hide_input_from_viewers": true,
475
+ "hide_output_from_viewers": true,
476
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
477
+ "type": "CODE"
478
+ },
479
+ "gather": {
480
+ "logged": 1706449743072
481
+ }
482
+ }
483
+ },
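The custom `compute_loss` treats the task as multi-label: `BCEWithLogitsLoss` applies a per-label sigmoid and scores each of the four classes independently, rather than softmaxing across classes. A minimal sketch on a toy batch (shapes and values are assumptions for illustration):

```python
import torch

# Toy batch: 2 reports x 4 labels, raw logits and multi-hot targets.
logits = torch.tensor([[ 1.2, -0.3, 0.8, -2.0],
                       [-0.5,  2.1, 0.0,  0.3]])
labels = torch.tensor([[1., 0., 1., 0.],
                       [0., 1., 0., 1.]])

loss = torch.nn.BCEWithLogitsLoss()(logits, labels)
print(loss.item())  # mean binary cross-entropy over all 8 label slots
```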
484
+ {
485
+ "cell_type": "code",
486
+ "source": [
487
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
488
+ ],
489
+ "outputs": [
490
+ {
491
+ "output_type": "stream",
492
+ "name": "stderr",
493
+ "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
494
+ }
495
+ ],
496
+ "execution_count": 18,
497
+ "metadata": {
498
+ "datalore": {
499
+ "hide_input_from_viewers": true,
500
+ "hide_output_from_viewers": true,
501
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
502
+ "type": "CODE"
503
+ },
504
+ "gather": {
505
+ "logged": 1706449761205
506
+ }
507
+ }
508
+ },
509
+ {
510
+ "cell_type": "code",
511
+ "source": [
512
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
513
+ " y_pred = torch.from_numpy(y_pred)\n",
514
+ " y_true = torch.from_numpy(y_true)\n",
515
+ "\n",
516
+ " if sigmoid:\n",
517
+ " y_pred = y_pred.sigmoid()\n",
518
+ "\n",
519
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
520
+ ],
521
+ "outputs": [],
522
+ "execution_count": 19,
523
+ "metadata": {
524
+ "datalore": {
525
+ "hide_input_from_viewers": true,
526
+ "hide_output_from_viewers": true,
527
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
528
+ "type": "CODE"
529
+ },
530
+ "gather": {
531
+ "logged": 1706449761541
532
+ }
533
+ }
534
+ },
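A quick worked example of `accuracy_threshold` on toy values (one report, four labels), reusing the function defined above:

```python
import numpy as np

y_pred = np.array([[2.0, -1.0, 0.5, -3.0]])  # raw logits for one report
y_true = np.array([[1.0,  0.0, 0.0,  0.0]])

# After the sigmoid, the 0.5-thresholded predictions are [1, 0, 1, 0],
# matching 3 of the 4 true labels.
print(accuracy_threshold(y_pred, y_true))  # 0.75
```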
535
+ {
536
+ "cell_type": "code",
537
+ "source": [
538
+ "def compute_metrics(eval_pred):\n",
539
+ " predictions, labels = eval_pred\n",
540
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
541
+ ],
542
+ "outputs": [],
543
+ "execution_count": 20,
544
+ "metadata": {
545
+ "datalore": {
546
+ "hide_input_from_viewers": true,
547
+ "hide_output_from_viewers": true,
548
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
549
+ "type": "CODE"
550
+ },
551
+ "gather": {
552
+ "logged": 1706449761720
553
+ }
554
+ }
555
+ },
556
+ {
557
+ "cell_type": "code",
558
+ "source": [
559
+ "args = TrainingArguments(\n",
560
+ " output_dir=\"vaers\",\n",
561
+ " evaluation_strategy=\"epoch\",\n",
562
+ " learning_rate=2e-5,\n",
563
+ " per_device_train_batch_size=BATCH_SIZE,\n",
564
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
565
+ " num_train_epochs=EPOCHS,\n",
566
+ " weight_decay=.01,\n",
567
+ " logging_steps=1,\n",
568
+ " run_name=f\"daedra-training\",\n",
569
+ " report_to=[\"wandb\"]\n",
570
+ ")"
571
+ ],
572
+ "outputs": [],
573
+ "execution_count": 21,
574
+ "metadata": {
575
+ "datalore": {
576
+ "hide_input_from_viewers": true,
577
+ "hide_output_from_viewers": true,
578
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
579
+ "type": "CODE"
580
+ },
581
+ "gather": {
582
+ "logged": 1706449761893
583
+ }
584
+ }
585
+ },
586
+ {
587
+ "cell_type": "code",
588
+ "source": [
589
+ "multi_label_trainer = MultiLabelTrainer(\n",
590
+ " model, \n",
591
+ " args, \n",
592
+ " train_dataset=ds_enc[\"train\"], \n",
593
+ " eval_dataset=ds_enc[\"test\"], \n",
594
+ " compute_metrics=compute_metrics, \n",
595
+ " tokenizer=tokenizer\n",
596
+ ")"
597
+ ],
598
+ "outputs": [],
599
+ "execution_count": 22,
600
+ "metadata": {
601
+ "datalore": {
602
+ "hide_input_from_viewers": true,
603
+ "hide_output_from_viewers": true,
604
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
605
+ "type": "CODE"
606
+ },
607
+ "gather": {
608
+ "logged": 1706449769103
609
+ }
610
+ }
611
+ },
612
+ {
613
+ "cell_type": "code",
614
+ "source": [
615
+ "if SUBSAMPLING != 1.0:\n",
616
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
617
+ "else:\n",
618
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
619
+ " \n",
620
+ "wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
621
+ "\n",
622
+ "multi_label_trainer.evaluate()\n",
623
+ "wandb.finish()"
624
+ ],
625
+ "outputs": [
626
+ {
627
+ "output_type": "stream",
628
+ "name": "stderr",
629
+ "text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
630
+ },
631
+ {
632
+ "output_type": "display_data",
633
+ "data": {
634
+ "text/html": "Tracking run with wandb version 0.16.2",
635
+ "text/plain": "<IPython.core.display.HTML object>"
636
+ },
637
+ "metadata": {}
638
+ },
639
+ {
640
+ "output_type": "display_data",
641
+ "data": {
642
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141956-9lniqjvz</code>",
643
+ "text/plain": "<IPython.core.display.HTML object>"
644
+ },
645
+ "metadata": {}
646
+ },
647
+ {
648
+ "output_type": "display_data",
649
+ "data": {
650
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>",
651
+ "text/plain": "<IPython.core.display.HTML object>"
652
+ },
653
+ "metadata": {}
654
+ },
655
+ {
656
+ "output_type": "display_data",
657
+ "data": {
658
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>",
659
+ "text/plain": "<IPython.core.display.HTML object>"
660
+ },
661
+ "metadata": {}
662
+ },
663
+ {
664
+ "output_type": "display_data",
665
+ "data": {
666
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a>",
667
+ "text/plain": "<IPython.core.display.HTML object>"
668
+ },
669
+ "metadata": {}
670
+ },
671
+ {
672
+ "output_type": "display_data",
673
+ "data": {
674
+ "text/html": "Finishing last run (ID:9lniqjvz) before initializing another...",
675
+ "text/plain": "<IPython.core.display.HTML object>"
676
+ },
677
+ "metadata": {}
678
+ },
679
+ {
680
+ "output_type": "display_data",
681
+ "data": {
682
+ "text/html": " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)",
683
+ "text/plain": "<IPython.core.display.HTML object>"
684
+ },
685
+ "metadata": {}
686
+ },
687
+ {
688
+ "output_type": "display_data",
689
+ "data": {
690
+ "text/html": "Find logs at: <code>./wandb/run-20240128_141956-9lniqjvz/logs</code>",
691
+ "text/plain": "<IPython.core.display.HTML object>"
692
+ },
693
+ "metadata": {}
694
+ },
695
+ {
696
+ "output_type": "display_data",
697
+ "data": {
698
+ "text/html": "Successfully finished last run (ID:9lniqjvz). Initializing new run:<br/>",
699
+ "text/plain": "<IPython.core.display.HTML object>"
700
+ },
701
+ "metadata": {}
702
+ },
703
+ {
704
+ "output_type": "display_data",
705
+ "data": {
706
+ "text/html": "Tracking run with wandb version 0.16.2",
707
+ "text/plain": "<IPython.core.display.HTML object>"
708
+ },
709
+ "metadata": {}
710
+ },
711
+ {
712
+ "output_type": "display_data",
713
+ "data": {
714
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141958-5idmkcie</code>",
715
+ "text/plain": "<IPython.core.display.HTML object>"
716
+ },
717
+ "metadata": {}
718
+ },
719
+ {
720
+ "output_type": "display_data",
721
+ "data": {
722
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>",
723
+ "text/plain": "<IPython.core.display.HTML object>"
724
+ },
725
+ "metadata": {}
726
+ },
727
+ {
728
+ "output_type": "display_data",
729
+ "data": {
730
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>",
731
+ "text/plain": "<IPython.core.display.HTML object>"
732
+ },
733
+ "metadata": {}
734
+ },
735
+ {
736
+ "output_type": "display_data",
737
+ "data": {
738
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a>",
739
+ "text/plain": "<IPython.core.display.HTML object>"
740
+ },
741
+ "metadata": {}
742
+ },
743
+ {
744
+ "output_type": "display_data",
745
+ "data": {
746
+ "text/html": "\n <div>\n \n <progress value='1003' max='851' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [851/851 26:26]\n </div>\n ",
747
+ "text/plain": "<IPython.core.display.HTML object>"
748
+ },
749
+ "metadata": {}
750
+ },
751
+ {
752
+ "output_type": "display_data",
753
+ "data": {
754
+ "text/html": "<style>\n table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n </style>\n<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>▁</td></tr><tr><td>eval/loss</td><td>▁</td></tr><tr><td>eval/runtime</td><td>▁</td></tr><tr><td>eval/samples_per_second</td><td>▁</td></tr><tr><td>eval/steps_per_second</td><td>▁</td></tr><tr><td>train/global_step</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>0.55198</td></tr><tr><td>eval/loss</td><td>0.68442</td></tr><tr><td>eval/runtime</td><td>105.0436</td></tr><tr><td>eval/samples_per_second</td><td>259.168</td></tr><tr><td>eval/steps_per_second</td><td>8.101</td></tr><tr><td>train/global_step</td><td>0</td></tr></table><br/></div></div>",
755
+ "text/plain": "<IPython.core.display.HTML object>"
756
+ },
757
+ "metadata": {}
758
+ },
759
+ {
760
+ "output_type": "display_data",
761
+ "data": {
762
+ "text/html": " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)",
763
+ "text/plain": "<IPython.core.display.HTML object>"
764
+ },
765
+ "metadata": {}
766
+ },
767
+ {
768
+ "output_type": "display_data",
769
+ "data": {
770
+ "text/html": "Find logs at: <code>./wandb/run-20240128_141958-5idmkcie/logs</code>",
771
+ "text/plain": "<IPython.core.display.HTML object>"
772
+ },
773
+ "metadata": {}
774
+ }
775
+ ],
776
+ "execution_count": 23,
777
+ "metadata": {
778
+ "datalore": {
779
+ "hide_input_from_viewers": true,
780
+ "hide_output_from_viewers": true,
781
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
782
+ "type": "CODE"
783
+ },
784
+ "gather": {
785
+ "logged": 1706449880674
786
+ }
787
+ }
788
+ },
789
+ {
790
+ "cell_type": "code",
791
+ "source": [
792
+ "if SUBSAMPLING != 1.0:\n",
793
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
794
+ "else:\n",
795
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
796
+ " \n",
797
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
798
+ "\n",
799
+ "multi_label_trainer.train()\n",
800
+ "wandb.finish()"
801
+ ],
802
+ "outputs": [
803
+ {
804
+ "output_type": "display_data",
805
+ "data": {
806
+ "text/html": "Tracking run with wandb version 0.16.2",
807
+ "text/plain": "<IPython.core.display.HTML object>"
808
+ },
809
+ "metadata": {}
810
+ },
811
+ {
812
+ "output_type": "display_data",
813
+ "data": {
814
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_142151-2mcc0ibc</code>",
815
+ "text/plain": "<IPython.core.display.HTML object>"
816
+ },
817
+ "metadata": {}
818
+ },
819
+ {
820
+ "output_type": "display_data",
821
+ "data": {
822
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>",
823
+ "text/plain": "<IPython.core.display.HTML object>"
824
+ },
825
+ "metadata": {}
826
+ },
827
+ {
828
+ "output_type": "display_data",
829
+ "data": {
830
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>",
831
+ "text/plain": "<IPython.core.display.HTML object>"
832
+ },
833
+ "metadata": {}
834
+ },
835
+ {
836
+ "output_type": "display_data",
837
+ "data": {
838
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc</a>",
839
+ "text/plain": "<IPython.core.display.HTML object>"
840
+ },
841
+ "metadata": {}
842
+ },
843
+ {
844
+ "output_type": "display_data",
845
+ "data": {
846
+ "text/html": "\n <div>\n \n <progress value='3972' max='11913' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 3972/11913 24:20 < 48:40, 2.72 it/s, Epoch 1/3]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>",
847
+ "text/plain": "<IPython.core.display.HTML object>"
848
+ },
849
+ "metadata": {}
850
+ },
851
+ {
852
+ "output_type": "stream",
853
+ "name": "stderr",
854
+ "text": "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.6s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 22.7s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 14.0s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 15.2s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.0s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 12.4s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 13.4s\n"
855
+ }
856
+ ],
857
+ "execution_count": null,
858
+ "metadata": {
859
+ "datalore": {
860
+ "hide_input_from_viewers": true,
861
+ "hide_output_from_viewers": true,
862
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
863
+ "type": "CODE"
864
+ },
865
+ "gather": {
866
+ "logged": 1706449934637
867
+ }
868
+ }
869
+ },
870
+ {
871
+ "cell_type": "markdown",
872
+ "source": [
873
+ "### Evaluation"
874
+ ],
875
+ "metadata": {}
876
+ },
877
+ {
878
+ "cell_type": "markdown",
879
+ "source": [
880
+ "We instantiate a classifier `pipeline` and push it to CUDA."
881
+ ],
882
+ "metadata": {}
883
+ },
884
+ {
885
+ "cell_type": "code",
886
+ "source": [
887
+ "classifier = pipeline(\"text-classification\", \n",
888
+ " model, \n",
889
+ " tokenizer=tokenizer, \n",
890
+ " device=\"cuda:0\")"
891
+ ],
892
+ "outputs": [],
893
+ "execution_count": null,
894
+ "metadata": {
895
+ "datalore": {
896
+ "hide_input_from_viewers": true,
897
+ "hide_output_from_viewers": true,
898
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
899
+ "type": "CODE"
900
+ },
901
+ "gather": {
902
+ "logged": 1706411459928
903
+ }
904
+ }
905
+ },
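Once instantiated, the pipeline can be called on a single report; the texts below are made-up examples, and passing `top_k=None` returns a score for every class rather than only the best one:

```python
# Illustrative inference on toy report texts (not from VAERS).
classifier("Patient experienced injection-site soreness and a mild fever for two days.")
classifier("Patient was admitted overnight for observation.", top_k=None)
```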
906
+ {
907
+ "cell_type": "markdown",
908
+ "source": [
909
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
910
+ ],
911
+ "metadata": {}
912
+ },
913
+ {
914
+ "cell_type": "code",
915
+ "source": [
916
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
917
+ " max_length=None, \n",
918
+ " padding='max_length', \n",
919
+ " return_token_type_ids=True, \n",
920
+ " truncation=True)"
921
+ ],
922
+ "outputs": [],
923
+ "execution_count": null,
924
+ "metadata": {
925
+ "datalore": {
926
+ "hide_input_from_viewers": true,
927
+ "hide_output_from_viewers": true,
928
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
929
+ "type": "CODE"
930
+ },
931
+ "gather": {
932
+ "logged": 1706411523285
933
+ }
934
+ }
935
+ },
936
+ {
937
+ "cell_type": "markdown",
938
+ "source": [
939
+ "Once we've made the data loadable by putting it into a `DataLoader`, we "
940
+ ],
941
+ "metadata": {}
942
+ },
943
+ {
944
+ "cell_type": "code",
945
+ "source": [
946
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
947
+ " torch.tensor(test_encodings['attention_mask']), \n",
948
+ " torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
949
+ " torch.tensor(test_encodings['token_type_ids']))\n",
950
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
951
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
952
+ " batch_size=BATCH_SIZE)"
953
+ ],
954
+ "outputs": [],
955
+ "execution_count": null,
956
+ "metadata": {
957
+ "datalore": {
958
+ "hide_input_from_viewers": true,
959
+ "hide_output_from_viewers": true,
960
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
961
+ "type": "CODE"
962
+ },
963
+ "gather": {
964
+ "logged": 1706411543379
965
+ }
966
+ }
967
+ },
968
+ {
969
+ "cell_type": "code",
970
+ "source": [
971
+ "model.eval()\n",
972
+ "\n",
973
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
974
+ "\n",
975
+ "for i, batch in enumerate(test_dataloader):\n",
976
+ " batch = tuple(t.to(device) for t in batch)\n",
977
+ " \n",
978
+ " # Unpack the inputs from our dataloader\n",
979
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
980
+ " \n",
981
+ " with torch.no_grad():\n",
982
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
983
+ " b_logit_pred = outs[0]\n",
984
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
985
+ "\n",
986
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
987
+ " pred_label = pred_label.to('cpu').numpy()\n",
988
+ " b_labels = b_labels.to('cpu').numpy()\n",
989
+ "\n",
990
+ " tokenized_texts.append(b_input_ids)\n",
991
+ " logit_preds.append(b_logit_pred)\n",
992
+ " true_labels.append(b_labels)\n",
993
+ " pred_labels.append(pred_label)\n",
994
+ "\n",
995
+ "# Flatten outputs\n",
996
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
997
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
998
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
999
+ "\n",
1000
+ "# Converting flattened binary values to boolean values\n",
1001
+ "true_bools = [tl == 1 for tl in true_labels]\n",
1002
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
1003
+ ],
1004
+ "outputs": [],
1005
+ "execution_count": null,
1006
+ "metadata": {
1007
+ "datalore": {
1008
+ "hide_input_from_viewers": true,
1009
+ "hide_output_from_viewers": true,
1010
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
1011
+ "type": "CODE"
1012
+ },
1013
+ "gather": {
1014
+ "logged": 1706411587843
1015
+ }
1016
+ }
1017
+ },
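The 0.50 cut-off behind `pred_bools` is a fixed assumption; a rough sketch of sweeping the threshold to see its effect on micro-F1, reusing the arrays gathered above:

```python
import numpy as np
from sklearn.metrics import f1_score

preds = np.array(pred_labels)              # sigmoid outputs, (n_samples, n_labels)
trues = np.array(true_labels).astype(int)  # multi-hot ground truth

for t in (0.3, 0.4, 0.5, 0.6):
    print(f"threshold={t}: micro-F1={f1_score(trues, preds > t, average='micro'):.4f}")
```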
1018
+ {
1019
+ "cell_type": "markdown",
1020
+ "source": [
1021
+ "We create a classification report:"
1022
+ ],
1023
+ "metadata": {}
1024
+ },
1025
+ {
1026
+ "cell_type": "code",
1027
+ "source": [
1028
+ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
1029
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
1030
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
1031
+ "print(clf_report)"
1032
+ ],
1033
+ "outputs": [],
1034
+ "execution_count": null,
1035
+ "metadata": {
1036
+ "datalore": {
1037
+ "hide_input_from_viewers": true,
1038
+ "hide_output_from_viewers": true,
1039
+ "node_id": "eBprrgF086mznPbPVBpOLS",
1040
+ "type": "CODE"
1041
+ },
1042
+ "gather": {
1043
+ "logged": 1706411588249
1044
+ }
1045
+ }
1046
+ },
1047
+ {
1048
+ "cell_type": "markdown",
1049
+ "source": [
1050
+ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
1051
+ ],
1052
+ "metadata": {}
1053
+ },
1054
+ {
1055
+ "cell_type": "code",
1056
+ "source": [
1057
+ "# Creating a map of class names from class numbers\n",
1058
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
1059
+ ],
1060
+ "outputs": [],
1061
+ "execution_count": null,
1062
+ "metadata": {
1063
+ "datalore": {
1064
+ "hide_input_from_viewers": true,
1065
+ "hide_output_from_viewers": true,
1066
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
1067
+ "type": "CODE"
1068
+ },
1069
+ "gather": {
1070
+ "logged": 1706411588638
1071
+ }
1072
+ }
1073
+ },
1074
+ {
1075
+ "cell_type": "code",
1076
+ "source": [
1077
+ "true_label_idxs, pred_label_idxs = [], []\n",
1078
+ "\n",
1079
+ "for vals in true_bools:\n",
1080
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
1081
+ "for vals in pred_bools:\n",
1082
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
1083
+ ],
1084
+ "outputs": [],
1085
+ "execution_count": null,
1086
+ "metadata": {
1087
+ "datalore": {
1088
+ "hide_input_from_viewers": true,
1089
+ "hide_output_from_viewers": true,
1090
+ "node_id": "jH0S35dDteUch01sa6me6e",
1091
+ "type": "CODE"
1092
+ },
1093
+ "gather": {
1094
+ "logged": 1706411589004
1095
+ }
1096
+ }
1097
+ },
1098
+ {
1099
+ "cell_type": "code",
1100
+ "source": [
1101
+ "true_label_texts, pred_label_texts = [], []\n",
1102
+ "\n",
1103
+ "for vals in true_label_idxs:\n",
1104
+ " if vals:\n",
1105
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
1106
+ " else:\n",
1107
+ " true_label_texts.append(vals)\n",
1108
+ "\n",
1109
+ "for vals in pred_label_idxs:\n",
1110
+ " if vals:\n",
1111
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
1112
+ " else:\n",
1113
+ " pred_label_texts.append(vals)"
1114
+ ],
1115
+ "outputs": [],
1116
+ "execution_count": null,
1117
+ "metadata": {
1118
+ "datalore": {
1119
+ "hide_input_from_viewers": true,
1120
+ "hide_output_from_viewers": true,
1121
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
1122
+ "type": "CODE"
1123
+ },
1124
+ "gather": {
1125
+ "logged": 1706411589301
1126
+ }
1127
+ }
1128
+ },
1129
+ {
1130
+ "cell_type": "code",
1131
+ "source": [
1132
+ "symptom_texts = [tokenizer.decode(text,\n",
1133
+ " skip_special_tokens=True,\n",
1134
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
1135
+ ],
1136
+ "outputs": [],
1137
+ "execution_count": null,
1138
+ "metadata": {
1139
+ "datalore": {
1140
+ "hide_input_from_viewers": true,
1141
+ "hide_output_from_viewers": true,
1142
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
1143
+ "type": "CODE"
1144
+ },
1145
+ "gather": {
1146
+ "logged": 1706411591952
1147
+ }
1148
+ }
1149
+ },
1150
+ {
1151
+ "cell_type": "code",
1152
+ "source": [
1153
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
1154
+ " 'true_labels': true_label_texts, \n",
1155
+ " 'pred_labels':pred_label_texts})\n",
1156
+ "comparisons_df.to_csv('comparisons.csv')\n",
1157
+ "comparisons_df"
1158
+ ],
1159
+ "outputs": [],
1160
+ "execution_count": null,
1161
+ "metadata": {
1162
+ "datalore": {
1163
+ "hide_input_from_viewers": true,
1164
+ "hide_output_from_viewers": true,
1165
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
1166
+ "type": "CODE"
1167
+ },
1168
+ "gather": {
1169
+ "logged": 1706411592512
1170
+ }
1171
+ }
1172
+ }
1173
+ ],
1174
+ "metadata": {
1175
+ "datalore": {
1176
+ "base_environment": "default",
1177
+ "computation_mode": "JUPYTER",
1178
+ "package_manager": "pip",
1179
+ "packages": [
1180
+ {
1181
+ "name": "datasets",
1182
+ "source": "PIP",
1183
+ "version": "2.16.1"
1184
+ },
1185
+ {
1186
+ "name": "torch",
1187
+ "source": "PIP",
1188
+ "version": "2.1.2"
1189
+ },
1190
+ {
1191
+ "name": "accelerate",
1192
+ "source": "PIP",
1193
+ "version": "0.26.1"
1194
+ }
1195
+ ],
1196
+ "report_row_ids": [
1197
+ "un8W7ez7ZwoGb5Co6nydEV",
1198
+ "40nN9Hvgi1clHNV5RAemI5",
1199
+ "TgRD90H5NSPpKS41OeXI1w",
1200
+ "ZOm5BfUs3h1EGLaUkBGeEB",
1201
+ "kOP0CZWNSk6vqE3wkPp7Vc",
1202
+ "W4PWcOu2O2pRaZyoE2W80h",
1203
+ "RolbOnQLIftk0vy9mIcz5M",
1204
+ "8OPhUgbaNJmOdiq5D3a6vK",
1205
+ "5Qrt3jSvSrpK6Ne1hS6shL",
1206
+ "hTq7nFUrovN5Ao4u6dIYWZ",
1207
+ "I8WNZLpJ1DVP2wiCW7YBIB",
1208
+ "SawhU3I9BewSE1XBPstpNJ",
1209
+ "80EtLEl2FIE4FqbWnUD3nT"
1210
+ ],
1211
+ "version": 3
1212
+ },
1213
+ "kernelspec": {
1214
+ "display_name": "Python 3.8 - Pytorch and Tensorflow",
1215
+ "language": "python",
1216
+ "name": "python38-azureml-pt-tf"
1217
+ },
1218
+ "language_info": {
1219
+ "name": "python",
1220
+ "version": "3.8.5",
1221
+ "mimetype": "text/x-python",
1222
+ "codemirror_mode": {
1223
+ "name": "ipython",
1224
+ "version": 3
1225
+ },
1226
+ "pygments_lexer": "ipython3",
1227
+ "nbconvert_exporter": "python",
1228
+ "file_extension": ".py"
1229
+ },
1230
+ "microsoft": {
1231
+ "host": {
1232
+ "AzureML": {
1233
+ "notebookHasBeenCompleted": true
1234
+ }
1235
+ },
1236
+ "ms_spell_check": {
1237
+ "ms_spell_check_language": "en"
1238
+ }
1239
+ },
1240
+ "nteract": {
1241
+ "version": "nteract-front-end@1.0.0"
1242
+ }
1243
+ },
1244
+ "nbformat": 4,
1245
+ "nbformat_minor": 4
1246
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-20-56-58Z.ipynb ADDED
@@ -0,0 +1,993 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "nteract": {
17
+ "transient": {
18
+ "deleting": false
19
+ }
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [],
24
+ "source": [
25
+ "# %pip install accelerate -U"
26
+ ]
27
+ },
28
+ {
29
+ "cell_type": "code",
30
+ "execution_count": 2,
31
+ "metadata": {
32
+ "nteract": {
33
+ "transient": {
34
+ "deleting": false
35
+ }
36
+ }
37
+ },
38
+ "outputs": [
39
+ {
40
+ "name": "stdout",
41
+ "output_type": "stream",
42
+ "text": [
43
+ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
44
+ "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
45
+ "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
46
+ "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
47
+ "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
48
+ "Requirement already satisfied: scikit-multilearn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.2.0)\n",
49
+ "Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
50
+ "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
51
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
52
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
53
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
54
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
55
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
56
+ "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
57
+ "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
58
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
59
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
60
+ "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
61
+ "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
62
+ "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
63
+ "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
64
+ "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
65
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
66
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
67
+ "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
68
+ "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
69
+ "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
70
+ "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
71
+ "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
72
+ "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
73
+ "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
74
+ "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
75
+ "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
76
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
77
+ "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
78
+ "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
79
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
80
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
81
+ "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
82
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
83
+ "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
84
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
85
+ "Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
86
+ "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
87
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
88
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
89
+ "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
90
+ "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
91
+ "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
92
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
93
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
94
+ "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
95
+ "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
96
+ "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
97
+ "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
98
+ "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
99
+ "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
100
+ "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
101
+ "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
102
+ "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
103
+ "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
104
+ "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
105
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
106
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
107
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
108
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
109
+ "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
110
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n",
111
+ "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
112
+ "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
113
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
114
+ "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
115
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
116
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
117
+ "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
118
+ "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
119
+ "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
120
+ "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
121
+ "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
122
+ "Note: you may need to restart the kernel to use updated packages.\n"
123
+ ]
124
+ }
125
+ ],
126
+ "source": [
127
+ "%pip install transformers datasets shap watermark wandb scikit-multilearn evaluate codecarbon"
128
+ ]
129
+ },
130
+ {
131
+ "cell_type": "code",
132
+ "execution_count": 3,
133
+ "metadata": {
134
+ "datalore": {
135
+ "hide_input_from_viewers": false,
136
+ "hide_output_from_viewers": false,
137
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
138
+ "report_properties": {
139
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
140
+ },
141
+ "type": "CODE"
142
+ },
143
+ "gather": {
144
+ "logged": 1706454586481
145
+ },
146
+ "tags": []
147
+ },
148
+ "outputs": [
149
+ {
150
+ "name": "stderr",
151
+ "output_type": "stream",
152
+ "text": [
153
+ "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
154
+ " from .autonotebook import tqdm as notebook_tqdm\n",
155
+ "2024-01-28 19:47:15.508449: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
156
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
157
+ "2024-01-28 19:47:16.502791: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
158
+ "2024-01-28 19:47:16.502915: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
159
+ "2024-01-28 19:47:16.502928: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
160
+ ]
161
+ }
162
+ ],
163
+ "source": [
164
+ "import pandas as pd\n",
165
+ "import numpy as np\n",
166
+ "import torch\n",
167
+ "import os\n",
168
+ "from typing import List, Union\n",
169
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
170
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
171
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
172
+ "from pyarrow import Table\n",
173
+ "import shap\n",
174
+ "import wandb\n",
175
+ "import evaluate\n",
176
+ "from codecarbon import EmissionsTracker\n",
177
+ "\n",
178
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
179
+ "tracker = EmissionsTracker()\n",
180
+ "\n",
181
+ "%load_ext watermark"
182
+ ]
183
+ },
184
+ {
185
+ "cell_type": "code",
186
+ "execution_count": 4,
187
+ "metadata": {
188
+ "collapsed": false,
189
+ "gather": {
190
+ "logged": 1706454586654
191
+ },
192
+ "jupyter": {
193
+ "outputs_hidden": false
194
+ }
195
+ },
196
+ "outputs": [],
197
+ "source": [
198
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
199
+ "\n",
200
+ "SEED: int = 42\n",
201
+ "\n",
202
+ "BATCH_SIZE: int = 16\n",
203
+ "EPOCHS: int = 3\n",
204
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
205
+ "\n",
206
+ "# WandB configuration\n",
207
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
208
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
209
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
210
+ ]
211
+ },
212
+ {
213
+ "cell_type": "code",
214
+ "execution_count": 5,
215
+ "metadata": {
216
+ "collapsed": false,
217
+ "jupyter": {
218
+ "outputs_hidden": false
219
+ }
220
+ },
221
+ "outputs": [
222
+ {
223
+ "name": "stdout",
224
+ "output_type": "stream",
225
+ "text": [
226
+ "numpy : 1.23.5\n",
227
+ "re : 2.2.1\n",
228
+ "evaluate: 0.4.1\n",
229
+ "pandas : 2.0.2\n",
230
+ "wandb : 0.16.2\n",
231
+ "shap : 0.44.1\n",
232
+ "torch : 1.12.0\n",
233
+ "logging : 0.5.1.2\n",
234
+ "\n"
235
+ ]
236
+ }
237
+ ],
238
+ "source": [
239
+ "%watermark --iversion"
240
+ ]
241
+ },
242
+ {
243
+ "cell_type": "code",
244
+ "execution_count": 6,
245
+ "metadata": {
246
+ "datalore": {
247
+ "hide_input_from_viewers": true,
248
+ "hide_output_from_viewers": true,
249
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
250
+ "type": "CODE"
251
+ }
252
+ },
253
+ "outputs": [
254
+ {
255
+ "name": "stdout",
256
+ "output_type": "stream",
257
+ "text": [
258
+ "Sun Jan 28 19:47:19 2024 \n",
259
+ "+---------------------------------------------------------------------------------------+\n",
260
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
261
+ "|-----------------------------------------+----------------------+----------------------+\n",
262
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
263
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
264
+ "| | | MIG M. |\n",
265
+ "|=========================================+======================+======================|\n",
266
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
267
+ "| N/A 29C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
268
+ "| | | N/A |\n",
269
+ "+-----------------------------------------+----------------------+----------------------+\n",
270
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
271
+ "| N/A 29C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
272
+ "| | | N/A |\n",
273
+ "+-----------------------------------------+----------------------+----------------------+\n",
274
+ " \n",
275
+ "+---------------------------------------------------------------------------------------+\n",
276
+ "| Processes: |\n",
277
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
278
+ "| ID ID Usage |\n",
279
+ "|=======================================================================================|\n",
280
+ "| No running processes found |\n",
281
+ "+---------------------------------------------------------------------------------------+\n"
282
+ ]
283
+ }
284
+ ],
285
+ "source": [
286
+ "!nvidia-smi"
287
+ ]
288
+ },
289
+ {
290
+ "cell_type": "markdown",
291
+ "metadata": {
292
+ "datalore": {
293
+ "hide_input_from_viewers": false,
294
+ "hide_output_from_viewers": false,
295
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
296
+ "report_properties": {
297
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
298
+ },
299
+ "type": "MD"
300
+ }
301
+ },
302
+ "source": [
303
+ "## Loading the data set"
304
+ ]
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "execution_count": 7,
309
+ "metadata": {
310
+ "collapsed": false,
311
+ "gather": {
312
+ "logged": 1706449040507
313
+ },
314
+ "jupyter": {
315
+ "outputs_hidden": false
316
+ }
317
+ },
318
+ "outputs": [],
319
+ "source": [
320
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
321
+ ]
322
+ },
323
+ {
324
+ "cell_type": "code",
325
+ "execution_count": 8,
326
+ "metadata": {
327
+ "collapsed": false,
328
+ "gather": {
329
+ "logged": 1706449044205
330
+ },
331
+ "jupyter": {
332
+ "outputs_hidden": false,
333
+ "source_hidden": false
334
+ },
335
+ "nteract": {
336
+ "transient": {
337
+ "deleting": false
338
+ }
339
+ }
340
+ },
341
+ "outputs": [
342
+ {
343
+ "data": {
344
+ "text/plain": [
345
+ "DatasetDict({\n",
346
+ " train: Dataset({\n",
347
+ " features: ['id', 'text', 'label'],\n",
348
+ " num_rows: 1270444\n",
349
+ " })\n",
350
+ " test: Dataset({\n",
351
+ " features: ['id', 'text', 'label'],\n",
352
+ " num_rows: 272238\n",
353
+ " })\n",
354
+ " val: Dataset({\n",
355
+ " features: ['id', 'text', 'label'],\n",
356
+ " num_rows: 272238\n",
357
+ " })\n",
358
+ "})"
359
+ ]
360
+ },
361
+ "execution_count": 8,
362
+ "metadata": {},
363
+ "output_type": "execute_result"
364
+ }
365
+ ],
366
+ "source": [
367
+ "dataset"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 9,
373
+ "metadata": {},
374
+ "outputs": [],
375
+ "source": [
376
+ "SUBSAMPLING = 0.1\n",
377
+ "\n",
378
+ "if SUBSAMPLING < 1:\n",
379
+ " _ = DatasetDict()\n",
380
+ " for each in dataset.keys():\n",
381
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
382
+ "\n",
383
+ " dataset = _"
384
+ ]
385
+ },
386
+ {
387
+ "cell_type": "markdown",
388
+ "metadata": {},
389
+ "source": [
390
+ "## Tokenisation and encoding"
391
+ ]
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "execution_count": 10,
396
+ "metadata": {},
397
+ "outputs": [],
398
+ "source": [
399
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
400
+ " return ds_enc"
401
+ ]
402
+ },
403
+ {
404
+ "cell_type": "markdown",
405
+ "metadata": {},
406
+ "source": [
407
+ "## Evaluation metrics"
408
+ ]
409
+ },
410
+ {
411
+ "cell_type": "code",
412
+ "execution_count": 11,
413
+ "metadata": {},
414
+ "outputs": [],
415
+ "source": [
416
+ "accuracy = evaluate.load(\"accuracy\")\n",
417
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
418
+ "f1 = evaluate.load(\"f1\")"
419
+ ]
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "execution_count": 12,
424
+ "metadata": {},
425
+ "outputs": [],
426
+ "source": [
427
+ "def compute_metrics(eval_pred):\n",
428
+ " predictions, labels = eval_pred\n",
429
+ " predictions = np.argmax(predictions, axis=1)\n",
430
+ " return {\n",
431
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
432
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
433
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
434
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
435
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
436
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
437
+ " }"
438
+ ]
439
+ },
440
+ {
441
+ "cell_type": "markdown",
442
+ "metadata": {},
443
+ "source": [
444
+ "## Training"
445
+ ]
446
+ },
447
+ {
448
+ "cell_type": "markdown",
449
+ "metadata": {},
450
+ "source": [
451
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "code",
456
+ "execution_count": 13,
457
+ "metadata": {},
458
+ "outputs": [],
459
+ "source": [
460
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
461
+ ]
462
+ },
463
+ {
464
+ "cell_type": "code",
465
+ "execution_count": 14,
466
+ "metadata": {},
467
+ "outputs": [
468
+ {
469
+ "name": "stderr",
470
+ "output_type": "stream",
471
+ "text": [
472
+ "Map: 100%|██████████| 127044/127044 [00:53<00:00, 2384.54 examples/s]\n",
473
+ "Map: 100%|██████████| 27223/27223 [00:11<00:00, 2396.71 examples/s]\n",
474
+ "Map: 100%|██████████| 27223/27223 [00:11<00:00, 2375.38 examples/s]\n",
475
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
476
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
477
+ ]
478
+ }
479
+ ],
480
+ "source": [
481
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
482
+ "\n",
483
+ "cols = dataset[\"train\"].column_names\n",
484
+ "cols.remove(\"label\")\n",
485
+ "ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n",
486
+ "\n",
487
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
488
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
489
+ " id2label=label_map, \n",
490
+ " label2id={v:k for k,v in label_map.items()})\n",
491
+ "\n",
492
+ "args = TrainingArguments(\n",
493
+ " output_dir=\"vaers\",\n",
494
+ " evaluation_strategy=\"epoch\",\n",
495
+ " save_strategy=\"epoch\",\n",
496
+ " learning_rate=2e-5,\n",
497
+ " per_device_train_batch_size=BATCH_SIZE,\n",
498
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
499
+ " num_train_epochs=EPOCHS,\n",
500
+ " weight_decay=.01,\n",
501
+ " logging_steps=1,\n",
502
+ " load_best_model_at_end=True,\n",
503
+ " run_name=f\"daedra-training\",\n",
504
+ " report_to=[\"wandb\"])\n",
505
+ "\n",
506
+ "trainer = Trainer(\n",
507
+ " model=model,\n",
508
+ " args=args,\n",
509
+ " train_dataset=ds_enc[\"train\"],\n",
510
+ " eval_dataset=ds_enc[\"test\"],\n",
511
+ " tokenizer=tokenizer,\n",
512
+ " compute_metrics=compute_metrics)"
513
+ ]
514
+ },
515
+ {
516
+ "cell_type": "code",
517
+ "execution_count": 15,
518
+ "metadata": {},
519
+ "outputs": [
520
+ {
521
+ "name": "stderr",
522
+ "output_type": "stream",
523
+ "text": [
524
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
525
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
526
+ ]
527
+ },
528
+ {
529
+ "data": {
530
+ "text/html": [
531
+ "Tracking run with wandb version 0.16.2"
532
+ ],
533
+ "text/plain": [
534
+ "<IPython.core.display.HTML object>"
535
+ ]
536
+ },
537
+ "metadata": {},
538
+ "output_type": "display_data"
539
+ },
540
+ {
541
+ "data": {
542
+ "text/html": [
543
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_194842-yvxddyg6</code>"
544
+ ],
545
+ "text/plain": [
546
+ "<IPython.core.display.HTML object>"
547
+ ]
548
+ },
549
+ "metadata": {},
550
+ "output_type": "display_data"
551
+ },
552
+ {
553
+ "data": {
554
+ "text/html": [
555
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
556
+ ],
557
+ "text/plain": [
558
+ "<IPython.core.display.HTML object>"
559
+ ]
560
+ },
561
+ "metadata": {},
562
+ "output_type": "display_data"
563
+ },
564
+ {
565
+ "data": {
566
+ "text/html": [
567
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
568
+ ],
569
+ "text/plain": [
570
+ "<IPython.core.display.HTML object>"
571
+ ]
572
+ },
573
+ "metadata": {},
574
+ "output_type": "display_data"
575
+ },
576
+ {
577
+ "data": {
578
+ "text/html": [
579
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6</a>"
580
+ ],
581
+ "text/plain": [
582
+ "<IPython.core.display.HTML object>"
583
+ ]
584
+ },
585
+ "metadata": {},
586
+ "output_type": "display_data"
587
+ },
588
+ {
589
+ "data": {
590
+ "text/html": [
591
+ "Finishing last run (ID:yvxddyg6) before initializing another..."
592
+ ],
593
+ "text/plain": [
594
+ "<IPython.core.display.HTML object>"
595
+ ]
596
+ },
597
+ "metadata": {},
598
+ "output_type": "display_data"
599
+ },
600
+ {
601
+ "data": {
602
+ "text/html": [
603
+ " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
604
+ ],
605
+ "text/plain": [
606
+ "<IPython.core.display.HTML object>"
607
+ ]
608
+ },
609
+ "metadata": {},
610
+ "output_type": "display_data"
611
+ },
612
+ {
613
+ "data": {
614
+ "text/html": [
615
+ "Find logs at: <code>./wandb/run-20240128_194842-yvxddyg6/logs</code>"
616
+ ],
617
+ "text/plain": [
618
+ "<IPython.core.display.HTML object>"
619
+ ]
620
+ },
621
+ "metadata": {},
622
+ "output_type": "display_data"
623
+ },
624
+ {
625
+ "data": {
626
+ "text/html": [
627
+ "Successfully finished last run (ID:yvxddyg6). Initializing new run:<br/>"
628
+ ],
629
+ "text/plain": [
630
+ "<IPython.core.display.HTML object>"
631
+ ]
632
+ },
633
+ "metadata": {},
634
+ "output_type": "display_data"
635
+ },
636
+ {
637
+ "data": {
638
+ "text/html": [
639
+ "Tracking run with wandb version 0.16.2"
640
+ ],
641
+ "text/plain": [
642
+ "<IPython.core.display.HTML object>"
643
+ ]
644
+ },
645
+ "metadata": {},
646
+ "output_type": "display_data"
647
+ },
648
+ {
649
+ "data": {
650
+ "text/html": [
651
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_194845-9g8te2gf</code>"
652
+ ],
653
+ "text/plain": [
654
+ "<IPython.core.display.HTML object>"
655
+ ]
656
+ },
657
+ "metadata": {},
658
+ "output_type": "display_data"
659
+ },
660
+ {
661
+ "data": {
662
+ "text/html": [
663
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
664
+ ],
665
+ "text/plain": [
666
+ "<IPython.core.display.HTML object>"
667
+ ]
668
+ },
669
+ "metadata": {},
670
+ "output_type": "display_data"
671
+ },
672
+ {
673
+ "data": {
674
+ "text/html": [
675
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
676
+ ],
677
+ "text/plain": [
678
+ "<IPython.core.display.HTML object>"
679
+ ]
680
+ },
681
+ "metadata": {},
682
+ "output_type": "display_data"
683
+ },
684
+ {
685
+ "data": {
686
+ "text/html": [
687
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf</a>"
688
+ ],
689
+ "text/plain": [
690
+ "<IPython.core.display.HTML object>"
691
+ ]
692
+ },
693
+ "metadata": {},
694
+ "output_type": "display_data"
695
+ },
696
+ {
697
+ "data": {
698
+ "text/html": [
699
+ "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
700
+ ],
701
+ "text/plain": [
702
+ "<wandb.sdk.wandb_run.Run at 0x7fb8b483bf40>"
703
+ ]
704
+ },
705
+ "execution_count": 15,
706
+ "metadata": {},
707
+ "output_type": "execute_result"
708
+ }
709
+ ],
710
+ "source": [
711
+ "if SUBSAMPLING != 1.0:\n",
712
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
713
+ "else:\n",
714
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
715
+ "\n",
716
+ "wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
717
+ "wandb_tag.append(f\"base:{model_ckpt}\")\n",
718
+ " \n",
719
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
720
+ ]
721
+ },
722
+ {
723
+ "cell_type": "code",
724
+ "execution_count": 16,
725
+ "metadata": {},
726
+ "outputs": [
727
+ {
728
+ "name": "stderr",
729
+ "output_type": "stream",
730
+ "text": [
731
+ "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
732
+ ]
733
+ },
734
+ {
735
+ "data": {
736
+ "text/html": [
737
+ "\n",
738
+ " <div>\n",
739
+ " \n",
740
+ " <progress value='7943' max='11913' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
741
+ " [ 7943/11913 43:43 < 21:51, 3.03 it/s, Epoch 2/3]\n",
742
+ " </div>\n",
743
+ " <table border=\"1\" class=\"dataframe\">\n",
744
+ " <thead>\n",
745
+ " <tr style=\"text-align: left;\">\n",
746
+ " <th>Epoch</th>\n",
747
+ " <th>Training Loss</th>\n",
748
+ " <th>Validation Loss</th>\n",
749
+ " <th>Accuracy</th>\n",
750
+ " <th>Precision Macroaverage</th>\n",
751
+ " <th>Precision Microaverage</th>\n",
752
+ " <th>Recall Macroaverage</th>\n",
753
+ " <th>Recall Microaverage</th>\n",
754
+ " <th>F1 Microaverage</th>\n",
755
+ " </tr>\n",
756
+ " </thead>\n",
757
+ " <tbody>\n",
758
+ " <tr>\n",
759
+ " <td>1</td>\n",
760
+ " <td>0.251300</td>\n",
761
+ " <td>0.362917</td>\n",
762
+ " <td>0.865775</td>\n",
763
+ " <td>0.701081</td>\n",
764
+ " <td>0.865775</td>\n",
765
+ " <td>0.556570</td>\n",
766
+ " <td>0.865775</td>\n",
767
+ " <td>0.865775</td>\n",
768
+ " </tr>\n",
769
+ " <tr>\n",
770
+ " <td>2</td>\n",
771
+ " <td>0.036000</td>\n",
772
+ " <td>0.352118</td>\n",
773
+ " <td>0.870551</td>\n",
774
+ " <td>0.728051</td>\n",
775
+ " <td>0.870551</td>\n",
776
+ " <td>0.609787</td>\n",
777
+ " <td>0.870551</td>\n",
778
+ " <td>0.870551</td>\n",
779
+ " </tr>\n",
780
+ " </tbody>\n",
781
+ "</table><p>"
782
+ ],
783
+ "text/plain": [
784
+ "<IPython.core.display.HTML object>"
785
+ ]
786
+ },
787
+ "metadata": {},
788
+ "output_type": "display_data"
789
+ },
790
+ {
791
+ "name": "stderr",
792
+ "output_type": "stream",
793
+ "text": [
794
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3971)... Done. 18.2s\n",
795
+ "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
796
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-7942)... "
797
+ ]
798
+ }
799
+ ],
800
+ "source": [
801
+ "tracker.start()\n",
802
+ "trainer.train()\n",
803
+ "tracker.stop()\n"
804
+ ]
805
+ },
806
+ {
807
+ "cell_type": "code",
808
+ "execution_count": null,
809
+ "metadata": {},
810
+ "outputs": [
811
+ {
812
+ "data": {
813
+ "text/html": [
814
+ "<style>\n",
815
+ " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
816
+ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
817
+ " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
818
+ " </style>\n",
819
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>▁▇█</td></tr><tr><td>eval/f1_microaverage</td><td>▁▇█</td></tr><tr><td>eval/loss</td><td>█▃▁</td></tr><tr><td>eval/precision_macroaverage</td><td>▁▇█</td></tr><tr><td>eval/precision_microaverage</td><td>▁▇█</td></tr><tr><td>eval/recall_macroaverage</td><td>▁▇█</td></tr><tr><td>eval/recall_microaverage</td><td>▁▇█</td></tr><tr><td>eval/runtime</td><td>▁▃█</td></tr><tr><td>eval/samples_per_second</td><td>█▆▁</td></tr><tr><td>eval/steps_per_second</td><td>█▆▁</td></tr><tr><td>train/epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>train/global_step</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>train/learning_rate</td><td>████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁</td></tr><tr><td>train/loss</td><td>█▅▆▆▅▄▄▃▆▅▃▃▅▄▆▄▄▄▂▄▄▅▄▃▄▄▁▄▂▂▃▃▃▂▂▃▂▃▃▂</td></tr><tr><td>train/total_flos</td><td>▁</td></tr><tr><td>train/train_loss</td><td>▁</td></tr><tr><td>train/train_runtime</td><td>▁</td></tr><tr><td>train/train_samples_per_second</td><td>▁</td></tr><tr><td>train/train_steps_per_second</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>0.84019</td></tr><tr><td>eval/f1_microaverage</td><td>0.84019</td></tr><tr><td>eval/loss</td><td>0.44011</td></tr><tr><td>eval/precision_macroaverage</td><td>0.415</td></tr><tr><td>eval/precision_microaverage</td><td>0.84019</td></tr><tr><td>eval/recall_macroaverage</td><td>0.40704</td></tr><tr><td>eval/recall_microaverage</td><td>0.84019</td></tr><tr><td>eval/runtime</td><td>10.0118</td></tr><tr><td>eval/samples_per_second</td><td>271.878</td></tr><tr><td>eval/steps_per_second</td><td>8.59</td></tr><tr><td>train/epoch</td><td>3.0</td></tr><tr><td>train/global_step</td><td>1191</td></tr><tr><td>train/learning_rate</td><td>0.0</td></tr><tr><td>train/loss</td><td>0.1782</td></tr><tr><td>train/total_flos</td><td>4885522962505728.0</td></tr><tr><td>train/train_loss</td><td>0.4724</td></tr><tr><td>train/train_runtime</td><td>483.5027</td></tr><tr><td>train/train_samples_per_second</td><td>78.825</td></tr><tr><td>train/train_steps_per_second</td><td>2.463</td></tr></table><br/></div></div>"
820
+ ],
821
+ "text/plain": [
822
+ "<IPython.core.display.HTML object>"
823
+ ]
824
+ },
825
+ "metadata": {},
826
+ "output_type": "display_data"
827
+ },
828
+ {
829
+ "data": {
830
+ "text/html": [
831
+ " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/3xvt3c2y' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/3xvt3c2y</a><br/>Synced 5 W&B file(s), 0 media file(s), 40 artifact file(s) and 0 other file(s)"
832
+ ],
833
+ "text/plain": [
834
+ "<IPython.core.display.HTML object>"
835
+ ]
836
+ },
837
+ "metadata": {},
838
+ "output_type": "display_data"
839
+ },
840
+ {
841
+ "data": {
842
+ "text/html": [
843
+ "Find logs at: <code>./wandb/run-20240128_192000-3xvt3c2y/logs</code>"
844
+ ],
845
+ "text/plain": [
846
+ "<IPython.core.display.HTML object>"
847
+ ]
848
+ },
849
+ "metadata": {},
850
+ "output_type": "display_data"
851
+ }
852
+ ],
853
+ "source": [
854
+ "wandb.finish()"
855
+ ]
856
+ },
857
+ {
858
+ "cell_type": "code",
859
+ "execution_count": null,
860
+ "metadata": {},
861
+ "outputs": [
862
+ {
863
+ "data": {
864
+ "text/plain": [
865
+ "CommitInfo(commit_url='https://huggingface.co/chrisvoncsefalvay/daedra/commit/c482ca6c8520142a3e67df4be25a408e6b557053', commit_message='DAEDRA model trained on 1.0% of the full sample of the VAERS dataset (training set size: 12,704)', commit_description='', oid='c482ca6c8520142a3e67df4be25a408e6b557053', pr_url=None, pr_revision=None, pr_num=None)"
866
+ ]
867
+ },
868
+ "execution_count": 31,
869
+ "metadata": {},
870
+ "output_type": "execute_result"
871
+ }
872
+ ],
873
+ "source": [
874
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
875
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
876
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
877
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
878
+ "\n",
879
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
880
+ " variant=variant,\n",
881
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
882
+ ]
883
+ },
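+ {
+ "cell_type": "markdown",
+ "metadata": {},
+ "source": [
+ "To sanity-check the upload, the model can be loaded back through the `pipeline` API. The cell below is a minimal sketch, assuming the push above succeeded and the Hub repository is accessible; the sample report text is invented for illustration."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal inference sketch: reload the model just pushed to the Hub (assumes the repository is accessible)\n",
+ "classifier = pipeline(\"text-classification\", model=\"chrisvoncsefalvay/daedra\", tokenizer=tokenizer)\n",
+ "\n",
+ "# Hypothetical report text, for illustration only\n",
+ "classifier(\"Patient developed a mild fever and myalgia the day after vaccination.\")"
+ ]
+ },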
884
+ {
885
+ "cell_type": "code",
886
+ "execution_count": null,
887
+ "metadata": {},
888
+ "outputs": [],
889
+ "source": [
890
+ "from collections import Counter\n",
891
+ "\n",
892
+ "def get_most_frequent_unknown_tokens(tokenizer, dataset):\n",
893
+ " unknown_tokens = []\n",
894
+ " \n",
895
+ " # Tokenize each text in the dataset\n",
896
+ " for example in dataset:\n",
897
+ " tokens = tokenizer.tokenize(example['text'])\n",
898
+ " \n",
899
+ " # Check if each token is the 'unknown' special token\n",
900
+ " for token in tokens:\n",
901
+ " if token == tokenizer.unk_token:\n",
902
+ " unknown_tokens.append(token)\n",
903
+ " \n",
904
+ " # Count the frequency of each unique unknown token\n",
905
+ " token_counts = Counter(unknown_tokens)\n",
906
+ " \n",
907
+ " # Sort the tokens based on their frequency in descending order\n",
908
+ " most_frequent_tokens = token_counts.most_common()\n",
909
+ " \n",
910
+ " return most_frequent_tokens\n",
911
+ "\n",
912
+ "# Example usage\n",
913
+ "tokenizer = YourTokenizer() # Replace with your tokenizer\n",
914
+ "dataset = YourDataset() # Replace with your dataset\n",
915
+ "\n",
916
+ "most_frequent_unknown_tokens = get_most_frequent_unknown_tokens(tokenizer, dataset)\n",
917
+ "print(most_frequent_unknown_tokens)\n"
918
+ ]
919
+ }
920
+ ],
921
+ "metadata": {
922
+ "datalore": {
923
+ "base_environment": "default",
924
+ "computation_mode": "JUPYTER",
925
+ "package_manager": "pip",
926
+ "packages": [
927
+ {
928
+ "name": "datasets",
929
+ "source": "PIP",
930
+ "version": "2.16.1"
931
+ },
932
+ {
933
+ "name": "torch",
934
+ "source": "PIP",
935
+ "version": "2.1.2"
936
+ },
937
+ {
938
+ "name": "accelerate",
939
+ "source": "PIP",
940
+ "version": "0.26.1"
941
+ }
942
+ ],
943
+ "report_row_ids": [
944
+ "un8W7ez7ZwoGb5Co6nydEV",
945
+ "40nN9Hvgi1clHNV5RAemI5",
946
+ "TgRD90H5NSPpKS41OeXI1w",
947
+ "ZOm5BfUs3h1EGLaUkBGeEB",
948
+ "kOP0CZWNSk6vqE3wkPp7Vc",
949
+ "W4PWcOu2O2pRaZyoE2W80h",
950
+ "RolbOnQLIftk0vy9mIcz5M",
951
+ "8OPhUgbaNJmOdiq5D3a6vK",
952
+ "5Qrt3jSvSrpK6Ne1hS6shL",
953
+ "hTq7nFUrovN5Ao4u6dIYWZ",
954
+ "I8WNZLpJ1DVP2wiCW7YBIB",
955
+ "SawhU3I9BewSE1XBPstpNJ",
956
+ "80EtLEl2FIE4FqbWnUD3nT"
957
+ ],
958
+ "version": 3
959
+ },
960
+ "kernelspec": {
961
+ "display_name": "Python 3.8 - Pytorch and Tensorflow",
962
+ "language": "python",
963
+ "name": "python38-azureml-pt-tf"
964
+ },
965
+ "language_info": {
966
+ "codemirror_mode": {
967
+ "name": "ipython",
968
+ "version": 3
969
+ },
970
+ "file_extension": ".py",
971
+ "mimetype": "text/x-python",
972
+ "name": "python",
973
+ "nbconvert_exporter": "python",
974
+ "pygments_lexer": "ipython3",
975
+ "version": "3.8.5"
976
+ },
977
+ "microsoft": {
978
+ "host": {
979
+ "AzureML": {
980
+ "notebookHasBeenCompleted": true
981
+ }
982
+ },
983
+ "ms_spell_check": {
984
+ "ms_spell_check_language": "en"
985
+ }
986
+ },
987
+ "nteract": {
988
+ "version": "nteract-front-end@1.0.0"
989
+ }
990
+ },
991
+ "nbformat": 4,
992
+ "nbformat_minor": 4
993
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-23-54-39Z.ipynb ADDED
@@ -0,0 +1,692 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {}
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "source": [
15
+ "%pip install accelerate -U"
16
+ ],
17
+ "outputs": [
18
+ {
19
+ "output_type": "stream",
20
+ "name": "stdout",
21
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
22
+ }
23
+ ],
24
+ "execution_count": 1,
25
+ "metadata": {
26
+ "nteract": {
27
+ "transient": {
28
+ "deleting": false
29
+ }
30
+ },
31
+ "tags": [],
32
+ "gather": {
33
+ "logged": 1706475754655
34
+ }
35
+ }
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "source": [
40
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
41
+ ],
42
+ "outputs": [
43
+ {
44
+ "output_type": "stream",
45
+ "name": "stdout",
46
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: pytz>=2020.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nNote: you may need to restart the kernel to use updated packages.\n"
47
+ }
48
+ ],
49
+ "execution_count": 2,
50
+ "metadata": {
51
+ "nteract": {
52
+ "transient": {
53
+ "deleting": false
54
+ }
55
+ }
56
+ }
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "import pandas as pd\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "import os\n",
65
+ "from typing import List, Union\n",
66
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
67
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
68
+ "import shap\n",
69
+ "import wandb\n",
70
+ "import evaluate\n",
71
+ "from codecarbon import EmissionsTracker\n",
72
+ "\n",
73
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
74
+ "tracker = EmissionsTracker()\n",
75
+ "\n",
76
+ "%load_ext watermark"
77
+ ],
78
+ "outputs": [
79
+ {
80
+ "output_type": "stream",
81
+ "name": "stderr",
82
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 21:14:33.562898: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 21:14:34.581816: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 21:14:34.581943: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 21:14:34.581956: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n[codecarbon INFO @ 21:14:37] [setup] RAM Tracking...\n[codecarbon INFO @ 21:14:37] [setup] GPU Tracking...\n[codecarbon INFO @ 21:14:37] Tracking Nvidia GPU via pynvml\n[codecarbon INFO @ 21:14:37] [setup] CPU Tracking...\n[codecarbon WARNING @ 21:14:37] No CPU tracking mode found. Falling back on CPU constant mode.\n[codecarbon WARNING @ 21:14:38] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n[codecarbon INFO @ 21:14:38] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 21:14:38] >>> Tracker's metadata:\n[codecarbon INFO @ 21:14:38] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n[codecarbon INFO @ 21:14:38] Python version: 3.8.5\n[codecarbon INFO @ 21:14:38] CodeCarbon version: 2.3.3\n[codecarbon INFO @ 21:14:38] Available RAM : 440.883 GB\n[codecarbon INFO @ 21:14:38] CPU count: 24\n[codecarbon INFO @ 21:14:38] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 21:14:38] GPU count: 4\n[codecarbon INFO @ 21:14:38] GPU model: 4 x Tesla V100-PCIE-16GB\n[codecarbon WARNING @ 21:14:38] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
83
+ }
84
+ ],
85
+ "execution_count": 3,
86
+ "metadata": {
87
+ "datalore": {
88
+ "hide_input_from_viewers": false,
89
+ "hide_output_from_viewers": false,
90
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
91
+ "report_properties": {
92
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
93
+ },
94
+ "type": "CODE"
95
+ },
96
+ "gather": {
97
+ "logged": 1706476478659
98
+ },
99
+ "tags": []
100
+ }
101
+ },
102
+ {
103
+ "cell_type": "code",
104
+ "source": [
105
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
106
+ "\n",
107
+ "SEED: int = 42\n",
108
+ "\n",
109
+ "BATCH_SIZE: int = 32\n",
110
+ "EPOCHS: int = 3\n",
111
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
112
+ "\n",
113
+ "# WandB configuration\n",
114
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
115
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
116
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
117
+ ],
118
+ "outputs": [],
119
+ "execution_count": 4,
120
+ "metadata": {
121
+ "collapsed": false,
122
+ "gather": {
123
+ "logged": 1706476478863
124
+ },
125
+ "jupyter": {
126
+ "outputs_hidden": false
127
+ }
128
+ }
129
+ },
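+ {
+ "cell_type": "code",
+ "source": [
+ "# A minimal reproducibility sketch (added; not part of the original run): SEED is used\n",
+ "# below for shuffling, and seeding the framework RNGs as well makes runs repeatable.\n",
+ "from transformers import set_seed\n",
+ "\n",
+ "set_seed(SEED)"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },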
130
+ {
131
+ "cell_type": "code",
132
+ "source": [
133
+ "%watermark --iversion"
134
+ ],
135
+ "outputs": [
136
+ {
137
+ "output_type": "stream",
138
+ "name": "stdout",
139
+ "text": "shap : 0.44.1\nre : 2.2.1\ntorch : 1.12.0\nevaluate: 0.4.1\nwandb : 0.16.2\nlogging : 0.5.1.2\npandas : 2.0.2\nnumpy : 1.23.5\n\n"
140
+ }
141
+ ],
142
+ "execution_count": 5,
143
+ "metadata": {
144
+ "collapsed": false,
145
+ "jupyter": {
146
+ "outputs_hidden": false
147
+ }
148
+ }
149
+ },
150
+ {
151
+ "cell_type": "code",
152
+ "source": [
153
+ "!nvidia-smi"
154
+ ],
155
+ "outputs": [
156
+ {
157
+ "output_type": "stream",
158
+ "name": "stdout",
159
+ "text": "Sun Jan 28 21:14:38 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
160
+ }
161
+ ],
162
+ "execution_count": 6,
163
+ "metadata": {
164
+ "datalore": {
165
+ "hide_input_from_viewers": true,
166
+ "hide_output_from_viewers": true,
167
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
168
+ "type": "CODE"
169
+ }
170
+ }
171
+ },
172
+ {
173
+ "cell_type": "markdown",
174
+ "source": [
175
+ "## Loading the data set"
176
+ ],
177
+ "metadata": {
178
+ "datalore": {
179
+ "hide_input_from_viewers": false,
180
+ "hide_output_from_viewers": false,
181
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
182
+ "report_properties": {
183
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
184
+ },
185
+ "type": "MD"
186
+ }
187
+ }
188
+ },
189
+ {
190
+ "cell_type": "code",
191
+ "source": [
192
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
193
+ ],
194
+ "outputs": [],
195
+ "execution_count": 7,
196
+ "metadata": {
197
+ "collapsed": false,
198
+ "gather": {
199
+ "logged": 1706476480469
200
+ },
201
+ "jupyter": {
202
+ "outputs_hidden": false
203
+ }
204
+ }
205
+ },
206
+ {
207
+ "cell_type": "code",
208
+ "source": [
209
+ "dataset"
210
+ ],
211
+ "outputs": [
212
+ {
213
+ "output_type": "execute_result",
214
+ "execution_count": 8,
215
+ "data": {
216
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
217
+ },
218
+ "metadata": {}
219
+ }
220
+ ],
221
+ "execution_count": 8,
222
+ "metadata": {
223
+ "collapsed": false,
224
+ "gather": {
225
+ "logged": 1706476480629
226
+ },
227
+ "jupyter": {
228
+ "outputs_hidden": false,
229
+ "source_hidden": false
230
+ },
231
+ "nteract": {
232
+ "transient": {
233
+ "deleting": false
234
+ }
235
+ }
236
+ }
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "source": [
241
+ "SUBSAMPLING = 0.5\n",
242
+ "\n",
243
+ "if SUBSAMPLING < 1:\n",
244
+ " _ = DatasetDict()\n",
245
+ " for each in dataset.keys():\n",
246
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
247
+ "\n",
248
+ " dataset = _"
249
+ ],
250
+ "outputs": [],
251
+ "execution_count": 9,
252
+ "metadata": {
253
+ "gather": {
254
+ "logged": 1706476480826
255
+ }
256
+ }
257
+ },
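+ {
+ "cell_type": "code",
+ "source": [
+ "# A sanity-check sketch (added; not part of the original run): random subsampling only\n",
+ "# approximately preserves class balance, so it is worth inspecting the label counts\n",
+ "# of each split after selection.\n",
+ "from collections import Counter\n",
+ "\n",
+ "{split: Counter(dataset[split][\"label\"]) for split in dataset.keys()}"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },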
258
+ {
259
+ "cell_type": "markdown",
260
+ "source": [
261
+ "## Tokenisation and encoding"
262
+ ],
263
+ "metadata": {}
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "source": [
268
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
269
+ " return ds_enc"
270
+ ],
271
+ "outputs": [],
272
+ "execution_count": 10,
273
+ "metadata": {
274
+ "gather": {
275
+ "logged": 1706476480944
276
+ }
277
+ }
278
+ },
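+ {
+ "cell_type": "code",
+ "source": [
+ "# A usage sketch for encode_ds (added; hypothetical, since the training cell below\n",
+ "# tokenises inline): encoding adds input_ids and attention_mask columns to each split.\n",
+ "# ds_enc = encode_ds(dataset)\n",
+ "# ds_enc[\"train\"].column_names"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },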
279
+ {
280
+ "cell_type": "markdown",
281
+ "source": [
282
+ "## Evaluation metrics"
283
+ ],
284
+ "metadata": {}
285
+ },
286
+ {
287
+ "cell_type": "code",
288
+ "source": [
289
+ "accuracy = evaluate.load(\"accuracy\")\n",
290
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
291
+ "f1 = evaluate.load(\"f1\")"
292
+ ],
293
+ "outputs": [],
294
+ "execution_count": 11,
295
+ "metadata": {
296
+ "gather": {
297
+ "logged": 1706476481192
298
+ }
299
+ }
300
+ },
301
+ {
302
+ "cell_type": "code",
303
+ "source": [
304
+ "def compute_metrics(eval_pred):\n",
305
+ " predictions, labels = eval_pred\n",
306
+ " predictions = np.argmax(predictions, axis=1)\n",
307
+ " return {\n",
308
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
309
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
310
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
311
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
312
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
313
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
314
+ " }"
315
+ ],
316
+ "outputs": [],
317
+ "execution_count": 12,
318
+ "metadata": {
319
+ "gather": {
320
+ "logged": 1706476481346
321
+ }
322
+ }
323
+ },
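+ {
+ "cell_type": "code",
+ "source": [
+ "# A minimal sketch (added; not part of the original run): sanity-check compute_metrics\n",
+ "# on dummy logits whose argmax matches the labels, so every metric should be 1.0.\n",
+ "dummy_logits = np.array([[2.0, 0.1, 0.2], [0.1, 3.0, 0.2]])\n",
+ "dummy_labels = np.array([0, 1])\n",
+ "compute_metrics((dummy_logits, dummy_labels))"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },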
324
+ {
325
+ "cell_type": "markdown",
326
+ "source": [
327
+ "## Training"
328
+ ],
329
+ "metadata": {}
330
+ },
331
+ {
332
+ "cell_type": "markdown",
333
+ "source": [
334
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
335
+ ],
336
+ "metadata": {}
337
+ },
338
+ {
339
+ "cell_type": "code",
340
+ "source": [
341
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
342
+ ],
343
+ "outputs": [],
344
+ "execution_count": 13,
345
+ "metadata": {
346
+ "gather": {
347
+ "logged": 1706476481593
348
+ }
349
+ }
350
+ },
351
+ {
352
+ "cell_type": "code",
353
+ "source": [
354
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
355
+ "\n",
356
+ "cols = dataset[\"train\"].column_names\n",
357
+ "cols.remove(\"label\")\n",
358
+ "ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n",
359
+ "\n",
360
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
361
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
362
+ " id2label=label_map, \n",
363
+ " label2id={v:k for k,v in label_map.items()})\n",
364
+ "\n",
365
+ "args = TrainingArguments(\n",
366
+ " output_dir=\"vaers\",\n",
367
+ " evaluation_strategy=\"epoch\",\n",
368
+ " save_strategy=\"epoch\",\n",
369
+ " learning_rate=2e-5,\n",
370
+ " per_device_train_batch_size=BATCH_SIZE,\n",
371
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
372
+ " num_train_epochs=EPOCHS,\n",
373
+ " weight_decay=.01,\n",
374
+ " logging_steps=1,\n",
375
+ " load_best_model_at_end=True,\n",
376
+ " run_name=f\"daedra-training\",\n",
377
+ " report_to=[\"wandb\"])\n",
378
+ "\n",
379
+ "trainer = Trainer(\n",
380
+ " model=model,\n",
381
+ " args=args,\n",
382
+ " train_dataset=ds_enc[\"train\"],\n",
383
+ " eval_dataset=ds_enc[\"test\"],\n",
384
+ " tokenizer=tokenizer,\n",
385
+ " compute_metrics=compute_metrics)"
386
+ ],
387
+ "outputs": [
388
+ {
389
+ "output_type": "stream",
390
+ "name": "stderr",
391
+ "text": "Map: 100%|██████████| 635222/635222 [04:25<00:00, 2395.47 examples/s]\nMap: 100%|██████████| 136119/136119 [00:56<00:00, 2405.75 examples/s]\nMap: 100%|██████████| 136119/136119 [00:56<00:00, 2422.27 examples/s]\nSome weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
392
+ }
393
+ ],
394
+ "execution_count": 14,
395
+ "metadata": {
396
+ "gather": {
397
+ "logged": 1706476861739
398
+ }
399
+ }
400
+ },
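+ {
+ "cell_type": "code",
+ "source": [
+ "# Collation note (added; relies on documented Trainer behaviour): with a tokenizer\n",
+ "# passed and no data_collator given, Trainer pads each batch dynamically, i.e.\n",
+ "# data_collator = DataCollatorWithPadding(tokenizer=tokenizer)\n",
+ "# which is why DataCollatorWithPadding is imported above."
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },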
401
+ {
402
+ "cell_type": "code",
403
+ "source": [
404
+ "if SUBSAMPLING != 1.0:\n",
405
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
406
+ "else:\n",
407
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
408
+ "\n",
409
+ "wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
410
+ "wandb_tag.append(f\"base:{model_ckpt}\")\n",
411
+ " \n",
412
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
413
+ ],
414
+ "outputs": [
415
+ {
416
+ "output_type": "stream",
417
+ "name": "stderr",
418
+ "text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
419
+ },
420
+ {
421
+ "output_type": "display_data",
422
+ "data": {
423
+ "text/plain": "<IPython.core.display.HTML object>",
424
+ "text/html": "Tracking run with wandb version 0.16.2"
425
+ },
426
+ "metadata": {}
427
+ },
428
+ {
429
+ "output_type": "display_data",
430
+ "data": {
431
+ "text/plain": "<IPython.core.display.HTML object>",
432
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_212103-403j5ij5</code>"
433
+ },
434
+ "metadata": {}
435
+ },
436
+ {
437
+ "output_type": "display_data",
438
+ "data": {
439
+ "text/plain": "<IPython.core.display.HTML object>",
440
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
441
+ },
442
+ "metadata": {}
443
+ },
444
+ {
445
+ "output_type": "display_data",
446
+ "data": {
447
+ "text/plain": "<IPython.core.display.HTML object>",
448
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
449
+ },
450
+ "metadata": {}
451
+ },
452
+ {
453
+ "output_type": "display_data",
454
+ "data": {
455
+ "text/plain": "<IPython.core.display.HTML object>",
456
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5</a>"
457
+ },
458
+ "metadata": {}
459
+ },
460
+ {
461
+ "output_type": "display_data",
462
+ "data": {
463
+ "text/plain": "<IPython.core.display.HTML object>",
464
+ "text/html": "Finishing last run (ID:403j5ij5) before initializing another..."
465
+ },
466
+ "metadata": {}
467
+ },
468
+ {
469
+ "output_type": "display_data",
470
+ "data": {
471
+ "text/plain": "<IPython.core.display.HTML object>",
472
+ "text/html": " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
473
+ },
474
+ "metadata": {}
475
+ },
476
+ {
477
+ "output_type": "display_data",
478
+ "data": {
479
+ "text/plain": "<IPython.core.display.HTML object>",
480
+ "text/html": "Find logs at: <code>./wandb/run-20240128_212103-403j5ij5/logs</code>"
481
+ },
482
+ "metadata": {}
483
+ },
484
+ {
485
+ "output_type": "display_data",
486
+ "data": {
487
+ "text/plain": "<IPython.core.display.HTML object>",
488
+ "text/html": "Successfully finished last run (ID:403j5ij5). Initializing new run:<br/>"
489
+ },
490
+ "metadata": {}
491
+ },
492
+ {
493
+ "output_type": "display_data",
494
+ "data": {
495
+ "text/plain": "<IPython.core.display.HTML object>",
496
+ "text/html": "Tracking run with wandb version 0.16.2"
497
+ },
498
+ "metadata": {}
499
+ },
500
+ {
501
+ "output_type": "display_data",
502
+ "data": {
503
+ "text/plain": "<IPython.core.display.HTML object>",
504
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_212105-q65k78ea</code>"
505
+ },
506
+ "metadata": {}
507
+ },
508
+ {
509
+ "output_type": "display_data",
510
+ "data": {
511
+ "text/plain": "<IPython.core.display.HTML object>",
512
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
513
+ },
514
+ "metadata": {}
515
+ },
516
+ {
517
+ "output_type": "display_data",
518
+ "data": {
519
+ "text/plain": "<IPython.core.display.HTML object>",
520
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
521
+ },
522
+ "metadata": {}
523
+ },
524
+ {
525
+ "output_type": "display_data",
526
+ "data": {
527
+ "text/plain": "<IPython.core.display.HTML object>",
528
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea</a>"
529
+ },
530
+ "metadata": {}
531
+ },
532
+ {
533
+ "output_type": "execute_result",
534
+ "execution_count": 15,
535
+ "data": {
536
+ "text/html": "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>",
537
+ "text/plain": "<wandb.sdk.wandb_run.Run at 0x7fea31898d90>"
538
+ },
539
+ "metadata": {}
540
+ }
541
+ ],
542
+ "execution_count": 15,
543
+ "metadata": {
544
+ "gather": {
545
+ "logged": 1706476872191
546
+ }
547
+ }
548
+ },
549
+ {
550
+ "cell_type": "code",
551
+ "source": [
552
+ "tracker.start()\n",
553
+ "trainer.train()\n",
554
+ "tracker.stop()\n"
555
+ ],
556
+ "outputs": [
557
+ {
558
+ "output_type": "stream",
559
+ "name": "stderr",
560
+ "text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
561
+ },
562
+ {
563
+ "output_type": "display_data",
564
+ "data": {
565
+ "text/plain": "<IPython.core.display.HTML object>",
566
+ "text/html": "\n <div>\n \n <progress value='2907' max='14889' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 2907/14889 31:21 < 2:09:21, 1.54 it/s, Epoch 0.59/3]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
567
+ },
568
+ "metadata": {}
569
+ },
570
+ {
571
+ "output_type": "stream",
572
+ "name": "stderr",
573
+ "text": "[codecarbon INFO @ 21:21:26] Energy consumed for RAM : 0.000690 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:21:26] Energy consumed for all GPUs : 0.001498 kWh. Total GPU Power : 359.01546838188807 W\n[codecarbon INFO @ 21:21:26] Energy consumed for all CPUs : 0.000177 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:21:26] 0.002365 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:21:41] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:21:41] Energy consumed for all GPUs : 0.004065 kWh. Total GPU Power : 616.9146793770267 W\n[codecarbon INFO @ 21:21:41] Energy consumed for all CPUs : 0.000354 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:21:41] 0.005797 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:21:56] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:21:56] Energy consumed for all GPUs : 0.006654 kWh. Total GPU Power : 621.6877665436252 W\n[codecarbon INFO @ 21:21:56] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:21:56] 0.009251 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:11] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:11] Energy consumed for all GPUs : 0.009260 kWh. Total GPU Power : 626.1437572465749 W\n[codecarbon INFO @ 21:22:11] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:11] 0.012723 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:26] Energy consumed for RAM : 0.003443 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:26] Energy consumed for all GPUs : 0.011865 kWh. Total GPU Power : 625.3693802936192 W\n[codecarbon INFO @ 21:22:26] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:26] 0.016193 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:41] Energy consumed for RAM : 0.004131 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:41] Energy consumed for all GPUs : 0.014488 kWh. Total GPU Power : 630.2419235639226 W\n[codecarbon INFO @ 21:22:41] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:41] 0.019682 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:56] Energy consumed for RAM : 0.004819 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:56] Energy consumed for all GPUs : 0.017135 kWh. Total GPU Power : 635.8556506868297 W\n[codecarbon INFO @ 21:22:56] Energy consumed for all CPUs : 0.001240 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:56] 0.023194 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:11] Energy consumed for RAM : 0.005507 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:11] Energy consumed for all GPUs : 0.019738 kWh. Total GPU Power : 625.0758518089303 W\n[codecarbon INFO @ 21:23:11] Energy consumed for all CPUs : 0.001417 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:11] 0.026662 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:26] Energy consumed for RAM : 0.006195 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:26] Energy consumed for all GPUs : 0.022385 kWh. Total GPU Power : 636.0572579593729 W\n[codecarbon INFO @ 21:23:26] Energy consumed for all CPUs : 0.001594 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:26] 0.030175 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:41] Energy consumed for RAM : 0.006883 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:41] Energy consumed for all GPUs : 0.025031 kWh. Total GPU Power : 635.4132918961806 W\n[codecarbon INFO @ 21:23:41] Energy consumed for all CPUs : 0.001771 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:41] 0.033685 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:56] Energy consumed for RAM : 0.007572 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:56] Energy consumed for all GPUs : 0.027661 kWh. Total GPU Power : 631.8222916777424 W\n[codecarbon INFO @ 21:23:56] Energy consumed for all CPUs : 0.001948 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:56] 0.037180 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:11] Energy consumed for RAM : 0.008260 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:11] Energy consumed for all GPUs : 0.030315 kWh. Total GPU Power : 637.844758085687 W\n[codecarbon INFO @ 21:24:11] Energy consumed for all CPUs : 0.002125 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:11] 0.040701 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:26] Energy consumed for RAM : 0.008948 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:26] Energy consumed for all GPUs : 0.032970 kWh. Total GPU Power : 637.7063667069607 W\n[codecarbon INFO @ 21:24:26] Energy consumed for all CPUs : 0.002302 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:26] 0.044220 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:41] Energy consumed for RAM : 0.009636 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:41] Energy consumed for all GPUs : 0.035629 kWh. Total GPU Power : 638.7595521159491 W\n[codecarbon INFO @ 21:24:41] Energy consumed for all CPUs : 0.002479 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:41] 0.047744 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:56] Energy consumed for RAM : 0.010324 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:56] Energy consumed for all GPUs : 0.038234 kWh. Total GPU Power : 626.0118880295652 W\n[codecarbon INFO @ 21:24:56] Energy consumed for all CPUs : 0.002657 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:56] 0.051214 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:11] Energy consumed for RAM : 0.011012 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:11] Energy consumed for all GPUs : 0.040892 kWh. Total GPU Power : 638.4170631771941 W\n[codecarbon INFO @ 21:25:11] Energy consumed for all CPUs : 0.002834 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:11] 0.054738 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:26] Energy consumed for RAM : 0.011700 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:26] Energy consumed for all GPUs : 0.043524 kWh. Total GPU Power : 632.34394576946 W\n[codecarbon INFO @ 21:25:26] Energy consumed for all CPUs : 0.003011 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:26] 0.058235 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:41] Energy consumed for RAM : 0.012388 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:41] Energy consumed for all GPUs : 0.046182 kWh. 
Total GPU Power : 638.4662389546352 W\n[codecarbon INFO @ 21:25:41] Energy consumed for all CPUs : 0.003188 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:41] 0.061758 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:56] Energy consumed for RAM : 0.013076 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:56] Energy consumed for all GPUs : 0.048838 kWh. Total GPU Power : 638.0871021853263 W\n[codecarbon INFO @ 21:25:56] Energy consumed for all CPUs : 0.003365 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:56] 0.065279 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:11] Energy consumed for RAM : 0.013765 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:11] Energy consumed for all GPUs : 0.051499 kWh. Total GPU Power : 639.0983849678707 W\n[codecarbon INFO @ 21:26:11] Energy consumed for all CPUs : 0.003542 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:11] 0.068806 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:26] Energy consumed for RAM : 0.014453 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:26] Energy consumed for all GPUs : 0.054132 kWh. Total GPU Power : 632.549674567773 W\n[codecarbon INFO @ 21:26:26] Energy consumed for all CPUs : 0.003719 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:26] 0.072304 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:41] Energy consumed for RAM : 0.015141 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:41] Energy consumed for all GPUs : 0.056768 kWh. Total GPU Power : 633.3096652159345 W\n[codecarbon INFO @ 21:26:41] Energy consumed for all CPUs : 0.003896 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:41] 0.075804 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:56] Energy consumed for RAM : 0.015829 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:56] Energy consumed for all GPUs : 0.059404 kWh. Total GPU Power : 633.339846576491 W\n[codecarbon INFO @ 21:26:56] Energy consumed for all CPUs : 0.004073 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:56] 0.079306 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:11] Energy consumed for RAM : 0.016517 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:11] Energy consumed for all GPUs : 0.062068 kWh. Total GPU Power : 639.9100492849137 W\n[codecarbon INFO @ 21:27:11] Energy consumed for all CPUs : 0.004250 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:11] 0.082835 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:26] Energy consumed for RAM : 0.017205 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:26] Energy consumed for all GPUs : 0.064726 kWh. Total GPU Power : 638.6437092393893 W\n[codecarbon INFO @ 21:27:26] Energy consumed for all CPUs : 0.004427 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:26] 0.086359 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:41] Energy consumed for RAM : 0.017893 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:41] Energy consumed for all GPUs : 0.067388 kWh. Total GPU Power : 639.3487979354586 W\n[codecarbon INFO @ 21:27:41] Energy consumed for all CPUs : 0.004604 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:41] 0.089885 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:56] Energy consumed for RAM : 0.018581 kWh. 
RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:56] Energy consumed for all GPUs : 0.070026 kWh. Total GPU Power : 633.6884387646057 W\n[codecarbon INFO @ 21:27:56] Energy consumed for all CPUs : 0.004781 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:56] 0.093389 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:11] Energy consumed for RAM : 0.019269 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:11] Energy consumed for all GPUs : 0.072687 kWh. Total GPU Power : 639.4422525221754 W\n[codecarbon INFO @ 21:28:11] Energy consumed for all CPUs : 0.004958 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:11] 0.096915 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:26] Energy consumed for RAM : 0.019958 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:26] Energy consumed for all GPUs : 0.075296 kWh. Total GPU Power : 626.9464464111006 W\n[codecarbon INFO @ 21:28:26] Energy consumed for all CPUs : 0.005135 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:26] 0.100390 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:41] Energy consumed for RAM : 0.020646 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:41] Energy consumed for all GPUs : 0.077962 kWh. Total GPU Power : 640.3962575270206 W\n[codecarbon INFO @ 21:28:41] Energy consumed for all CPUs : 0.005313 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:41] 0.103921 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:56] Energy consumed for RAM : 0.021334 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:56] Energy consumed for all GPUs : 0.080619 kWh. Total GPU Power : 638.3087387539953 W\n[codecarbon INFO @ 21:28:56] Energy consumed for all CPUs : 0.005490 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:56] 0.107443 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:11] Energy consumed for RAM : 0.022022 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:11] Energy consumed for all GPUs : 0.083270 kWh. Total GPU Power : 636.8708359764104 W\n[codecarbon INFO @ 21:29:11] Energy consumed for all CPUs : 0.005667 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:11] 0.110959 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:26] Energy consumed for RAM : 0.022710 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:26] Energy consumed for all GPUs : 0.085893 kWh. Total GPU Power : 630.0796388169725 W\n[codecarbon INFO @ 21:29:26] Energy consumed for all CPUs : 0.005844 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:26] 0.114447 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:41] Energy consumed for RAM : 0.023398 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:41] Energy consumed for all GPUs : 0.088548 kWh. Total GPU Power : 637.7758378447022 W\n[codecarbon INFO @ 21:29:41] Energy consumed for all CPUs : 0.006021 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:41] 0.117968 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:56] Energy consumed for RAM : 0.024087 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:56] Energy consumed for all GPUs : 0.091193 kWh. Total GPU Power : 634.8550146720521 W\n[codecarbon INFO @ 21:29:56] Energy consumed for all CPUs : 0.006198 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:56] 0.121478 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:11] Energy consumed for RAM : 0.024775 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:11] Energy consumed for all GPUs : 0.093817 kWh. Total GPU Power : 630.5186457226341 W\n[codecarbon INFO @ 21:30:11] Energy consumed for all CPUs : 0.006375 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:11] 0.124967 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:26] Energy consumed for RAM : 0.025463 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:26] Energy consumed for all GPUs : 0.096471 kWh. Total GPU Power : 637.5849420686613 W\n[codecarbon INFO @ 21:30:26] Energy consumed for all CPUs : 0.006552 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:26] 0.128486 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:41] Energy consumed for RAM : 0.026151 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:41] Energy consumed for all GPUs : 0.099109 kWh. Total GPU Power : 633.6189362439791 W\n[codecarbon INFO @ 21:30:41] Energy consumed for all CPUs : 0.006729 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:41] 0.131990 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:56] Energy consumed for RAM : 0.026839 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:56] Energy consumed for all GPUs : 0.101776 kWh. Total GPU Power : 640.6257944471723 W\n[codecarbon INFO @ 21:30:56] Energy consumed for all CPUs : 0.006906 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:56] 0.135522 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:11] Energy consumed for RAM : 0.027528 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:11] Energy consumed for all GPUs : 0.104439 kWh. Total GPU Power : 639.6643020904513 W\n[codecarbon INFO @ 21:31:11] Energy consumed for all CPUs : 0.007083 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:11] 0.139050 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:26] Energy consumed for RAM : 0.028216 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:26] Energy consumed for all GPUs : 0.107104 kWh. Total GPU Power : 640.0939172348444 W\n[codecarbon INFO @ 21:31:26] Energy consumed for all CPUs : 0.007260 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:26] 0.142580 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:41] Energy consumed for RAM : 0.028904 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:41] Energy consumed for all GPUs : 0.109747 kWh. Total GPU Power : 634.9935430223439 W\n[codecarbon INFO @ 21:31:41] Energy consumed for all CPUs : 0.007438 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:41] 0.146088 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:56] Energy consumed for RAM : 0.029592 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:56] Energy consumed for all GPUs : 0.112378 kWh. Total GPU Power : 632.1091666419065 W\n[codecarbon INFO @ 21:31:56] Energy consumed for all CPUs : 0.007615 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:56] 0.149585 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:11] Energy consumed for RAM : 0.030281 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:11] Energy consumed for all GPUs : 0.115023 kWh. 
Total GPU Power : 633.8442942682154 W\n[codecarbon INFO @ 21:32:11] Energy consumed for all CPUs : 0.007792 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:11] 0.153096 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:26] Energy consumed for RAM : 0.030968 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:26] Energy consumed for all GPUs : 0.117663 kWh. Total GPU Power : 635.0998743598051 W\n[codecarbon INFO @ 21:32:26] Energy consumed for all CPUs : 0.007969 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:26] 0.156599 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:41] Energy consumed for RAM : 0.031656 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:41] Energy consumed for all GPUs : 0.120332 kWh. Total GPU Power : 641.1040333084177 W\n[codecarbon INFO @ 21:32:41] Energy consumed for all CPUs : 0.008146 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:41] 0.160134 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:56] Energy consumed for RAM : 0.032344 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:56] Energy consumed for all GPUs : 0.122995 kWh. Total GPU Power : 639.7462391288783 W\n[codecarbon INFO @ 21:32:56] Energy consumed for all CPUs : 0.008323 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:56] 0.163662 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:11] Energy consumed for RAM : 0.033033 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:11] Energy consumed for all GPUs : 0.125648 kWh. Total GPU Power : 637.3370633888644 W\n[codecarbon INFO @ 21:33:11] Energy consumed for all CPUs : 0.008500 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:11] 0.167180 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:26] Energy consumed for RAM : 0.033721 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:26] Energy consumed for all GPUs : 0.128260 kWh. Total GPU Power : 627.497349520734 W\n[codecarbon INFO @ 21:33:26] Energy consumed for all CPUs : 0.008677 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:26] 0.170658 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:41] Energy consumed for RAM : 0.034409 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:41] Energy consumed for all GPUs : 0.130922 kWh. Total GPU Power : 639.378459827986 W\n[codecarbon INFO @ 21:33:41] Energy consumed for all CPUs : 0.008854 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:41] 0.174185 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:56] Energy consumed for RAM : 0.035097 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:56] Energy consumed for all GPUs : 0.133563 kWh. Total GPU Power : 634.2779963263187 W\n[codecarbon INFO @ 21:33:56] Energy consumed for all CPUs : 0.009031 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:56] 0.177692 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:11] Energy consumed for RAM : 0.035785 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:11] Energy consumed for all GPUs : 0.136211 kWh. Total GPU Power : 636.088462236655 W\n[codecarbon INFO @ 21:34:11] Energy consumed for all CPUs : 0.009208 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:11] 0.181205 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:26] Energy consumed for RAM : 0.036474 kWh. 
RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:26] Energy consumed for all GPUs : 0.138875 kWh. Total GPU Power : 639.8420566736949 W\n[codecarbon INFO @ 21:34:26] Energy consumed for all CPUs : 0.009385 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:26] 0.184734 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:41] Energy consumed for RAM : 0.037162 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:41] Energy consumed for all GPUs : 0.141537 kWh. Total GPU Power : 639.4628459940732 W\n[codecarbon INFO @ 21:34:41] Energy consumed for all CPUs : 0.009563 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:41] 0.188261 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:56] Energy consumed for RAM : 0.037850 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:56] Energy consumed for all GPUs : 0.144193 kWh. Total GPU Power : 638.284138549091 W\n[codecarbon INFO @ 21:34:56] Energy consumed for all CPUs : 0.009740 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:56] 0.191782 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:11] Energy consumed for RAM : 0.038538 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:11] Energy consumed for all GPUs : 0.146807 kWh. Total GPU Power : 627.9721129851367 W\n[codecarbon INFO @ 21:35:11] Energy consumed for all CPUs : 0.009917 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:11] 0.195262 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:26] Energy consumed for RAM : 0.039226 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:26] Energy consumed for all GPUs : 0.149465 kWh. Total GPU Power : 638.5284782703005 W\n[codecarbon INFO @ 21:35:26] Energy consumed for all CPUs : 0.010094 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:26] 0.198785 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:41] Energy consumed for RAM : 0.039914 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:41] Energy consumed for all GPUs : 0.152101 kWh. Total GPU Power : 633.1180716439897 W\n[codecarbon INFO @ 21:35:41] Energy consumed for all CPUs : 0.010271 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:41] 0.202286 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:56] Energy consumed for RAM : 0.040602 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:56] Energy consumed for all GPUs : 0.154764 kWh. Total GPU Power : 640.0670545574203 W\n[codecarbon INFO @ 21:35:56] Energy consumed for all CPUs : 0.010448 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:56] 0.205814 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:36:11] Energy consumed for RAM : 0.041290 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:36:11] Energy consumed for all GPUs : 0.157432 kWh. Total GPU Power : 640.6751187111053 W\n[codecarbon INFO @ 21:36:11] Energy consumed for all CPUs : 0.010625 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:36:11] 0.209347 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:36:26] Energy consumed for RAM : 0.041979 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:36:26] Energy consumed for all GPUs : 0.160090 kWh. Total GPU Power : 638.5720494734854 W\n[codecarbon INFO @ 21:36:26] Energy consumed for all CPUs : 0.010802 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:36:26] 0.212871 kWh of electricity used since the beginning.\n[codecarbon telemetry repeated every 15 s from 21:36:41 through 21:52:41 -- RAM Power : 165.33 W, Total CPU Power : 42.5 W, Total GPU Power : ~628-642 W; cumulative energy rose steadily to 0.441563 kWh of electricity used since the beginning]\n"
574
+ }
575
+ ],
576
+ "execution_count": 16,
577
+ "metadata": {
578
+ "gather": {
579
+ "logged": 1706476228988
580
+ }
581
+ }
582
+ },
583
+ {
584
+ "cell_type": "code",
585
+ "source": [
586
+ "wandb.finish()"
587
+ ],
588
+ "outputs": [],
589
+ "execution_count": null,
590
+ "metadata": {
591
+ "gather": {
592
+ "logged": 1706476229030
593
+ }
594
+ }
595
+ },
596
+ {
597
+ "cell_type": "code",
598
+ "source": [
599
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
600
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
601
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
602
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
603
+ "\n",
604
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
605
+ " variant=variant,\n",
606
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
607
+ ],
608
+ "outputs": [],
609
+ "execution_count": null,
610
+ "metadata": {
611
+ "gather": {
612
+ "logged": 1706476229038
613
+ }
614
+ }
615
+ }
616
+ ],
617
+ "metadata": {
618
+ "datalore": {
619
+ "base_environment": "default",
620
+ "computation_mode": "JUPYTER",
621
+ "package_manager": "pip",
622
+ "packages": [
623
+ {
624
+ "name": "datasets",
625
+ "source": "PIP",
626
+ "version": "2.16.1"
627
+ },
628
+ {
629
+ "name": "torch",
630
+ "source": "PIP",
631
+ "version": "2.1.2"
632
+ },
633
+ {
634
+ "name": "accelerate",
635
+ "source": "PIP",
636
+ "version": "0.26.1"
637
+ }
638
+ ],
639
+ "report_row_ids": [
640
+ "un8W7ez7ZwoGb5Co6nydEV",
641
+ "40nN9Hvgi1clHNV5RAemI5",
642
+ "TgRD90H5NSPpKS41OeXI1w",
643
+ "ZOm5BfUs3h1EGLaUkBGeEB",
644
+ "kOP0CZWNSk6vqE3wkPp7Vc",
645
+ "W4PWcOu2O2pRaZyoE2W80h",
646
+ "RolbOnQLIftk0vy9mIcz5M",
647
+ "8OPhUgbaNJmOdiq5D3a6vK",
648
+ "5Qrt3jSvSrpK6Ne1hS6shL",
649
+ "hTq7nFUrovN5Ao4u6dIYWZ",
650
+ "I8WNZLpJ1DVP2wiCW7YBIB",
651
+ "SawhU3I9BewSE1XBPstpNJ",
652
+ "80EtLEl2FIE4FqbWnUD3nT"
653
+ ],
654
+ "version": 3
655
+ },
656
+ "kernelspec": {
657
+ "name": "python38-azureml-pt-tf",
658
+ "language": "python",
659
+ "display_name": "Python 3.8 - Pytorch and Tensorflow"
660
+ },
661
+ "language_info": {
662
+ "name": "python",
663
+ "version": "3.8.5",
664
+ "mimetype": "text/x-python",
665
+ "codemirror_mode": {
666
+ "name": "ipython",
667
+ "version": 3
668
+ },
669
+ "pygments_lexer": "ipython3",
670
+ "nbconvert_exporter": "python",
671
+ "file_extension": ".py"
672
+ },
673
+ "microsoft": {
674
+ "host": {
675
+ "AzureML": {
676
+ "notebookHasBeenCompleted": true
677
+ }
678
+ },
679
+ "ms_spell_check": {
680
+ "ms_spell_check_language": "en"
681
+ }
682
+ },
683
+ "nteract": {
684
+ "version": "nteract-front-end@1.0.0"
685
+ },
686
+ "kernel_info": {
687
+ "name": "python38-azureml-pt-tf"
688
+ }
689
+ },
690
+ "nbformat": 4,
691
+ "nbformat_minor": 4
692
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-3-12-1Z.ipynb ADDED
@@ -0,0 +1,1053 @@
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {
11
+ "collapsed": false
12
+ }
13
+ },
14
+ {
15
+ "cell_type": "code",
16
+ "source": [
17
+ "%pip install accelerate -U"
18
+ ],
19
+ "outputs": [
20
+ {
21
+ "output_type": "stream",
22
+ "name": "stdout",
23
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nNote: you may need to restart the kernel to use updated packages.\n"
24
+ }
25
+ ],
26
+ "execution_count": 1,
27
+ "metadata": {
28
+ "jupyter": {
29
+ "source_hidden": false,
30
+ "outputs_hidden": false
31
+ },
32
+ "nteract": {
33
+ "transient": {
34
+ "deleting": false
35
+ }
36
+ }
37
+ }
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "source": [
42
+ "%pip install transformers datasets shap watermark wandb"
43
+ ],
44
+ "outputs": [
45
+ {
46
+ "output_type": "stream",
47
+ "name": "stderr",
48
+ "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
49
+ },
50
+ {
51
+ "output_type": "stream",
52
+ "name": "stdout",
53
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nCollecting wandb\n Using cached wandb-0.16.2-py3-none-any.whl (2.2 MB)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) 
(1.2.2)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nCollecting sentry-sdk>=1.0.0\n Using cached sentry_sdk-1.39.2-py2.py3-none-any.whl (254 kB)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nCollecting docker-pycreds>=0.4.0\n Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nCollecting setproctitle\n Using cached setproctitle-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\nCollecting appdirs>=1.4.3\n Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already 
satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) 
(2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: appdirs, setproctitle, sentry-sdk, docker-pycreds, wandb\nSuccessfully installed appdirs-1.4.4 docker-pycreds-0.4.0 sentry-sdk-1.39.2 setproctitle-1.3.3 wandb-0.16.2\nNote: you may need to restart the kernel to use updated packages.\n"
54
+ }
55
+ ],
56
+ "execution_count": 17,
57
+ "metadata": {
58
+ "jupyter": {
59
+ "source_hidden": false,
60
+ "outputs_hidden": false
61
+ },
62
+ "nteract": {
63
+ "transient": {
64
+ "deleting": false
65
+ }
66
+ }
67
+ }
68
+ },
69
+ {
70
+ "cell_type": "code",
71
+ "source": [
72
+ "import pandas as pd\n",
73
+ "import numpy as np\n",
74
+ "import torch\n",
75
+ "import os\n",
76
+ "from typing import List\n",
77
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
78
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
79
+ "from datasets import load_dataset\n",
80
+ "import shap\n",
81
+ "\n",
82
+ "%load_ext watermark"
83
+ ],
84
+ "outputs": [
85
+ {
86
+ "output_type": "stream",
87
+ "name": "stderr",
88
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 02:27:28.730200: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 02:27:29.708865: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 02:27:29.708983: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 02:27:29.708996: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
89
+ }
90
+ ],
91
+ "execution_count": 3,
92
+ "metadata": {
93
+ "datalore": {
94
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
95
+ "type": "CODE",
96
+ "hide_input_from_viewers": false,
97
+ "hide_output_from_viewers": false,
98
+ "report_properties": {
99
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
100
+ }
101
+ },
102
+ "gather": {
103
+ "logged": 1706408851775
104
+ }
105
+ }
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "source": [
110
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
111
+ "\n",
112
+ "SEED: int = 42\n",
113
+ "\n",
114
+ "BATCH_SIZE: int = 8\n",
115
+ "EPOCHS: int = 1\n",
116
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
117
+ "\n",
118
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
119
+ " \"ER_VISIT\",\n",
120
+ " \"HOSPITAL\",\n",
121
+ " \"OFC_VISIT\",\n",
122
+ " \"X_STAY\",\n",
123
+ " \"DISABLE\",\n",
124
+ " \"D_PRESENTED\"]\n",
125
+ "\n",
126
+ "# WandB configuration\n",
127
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
128
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
129
+ ],
130
+ "outputs": [],
131
+ "execution_count": 4,
132
+ "metadata": {
133
+ "collapsed": false,
134
+ "gather": {
135
+ "logged": 1706408852045
136
+ }
137
+ }
138
+ },
139
+ {
140
+ "cell_type": "code",
141
+ "source": [
142
+ "%watermark --iversion"
143
+ ],
144
+ "outputs": [
145
+ {
146
+ "output_type": "stream",
147
+ "name": "stdout",
148
+ "text": "re : 2.2.1\nnumpy : 1.23.5\nlogging: 0.5.1.2\npandas : 2.0.2\ntorch : 1.12.0\nshap : 0.44.1\n\n"
149
+ }
150
+ ],
151
+ "execution_count": 5,
152
+ "metadata": {
153
+ "collapsed": false
154
+ }
155
+ },
156
+ {
157
+ "cell_type": "code",
158
+ "source": [
159
+ "!nvidia-smi"
160
+ ],
161
+ "outputs": [
162
+ {
163
+ "output_type": "stream",
164
+ "name": "stdout",
165
+ "text": "Sun Jan 28 02:27:31 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 27C P0 36W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
166
+ }
167
+ ],
168
+ "execution_count": 6,
169
+ "metadata": {
170
+ "datalore": {
171
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
172
+ "type": "CODE",
173
+ "hide_input_from_viewers": true,
174
+ "hide_output_from_viewers": true
175
+ }
176
+ }
177
+ },
178
+ {
179
+ "attachments": {},
180
+ "cell_type": "markdown",
181
+ "source": [
182
+ "## Loading the data set"
183
+ ],
184
+ "metadata": {
185
+ "datalore": {
186
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
187
+ "type": "MD",
188
+ "hide_input_from_viewers": false,
189
+ "hide_output_from_viewers": false,
190
+ "report_properties": {
191
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
192
+ }
193
+ }
194
+ }
195
+ },
196
+ {
197
+ "cell_type": "code",
198
+ "source": [
199
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
200
+ ],
201
+ "outputs": [],
202
+ "execution_count": 7,
203
+ "metadata": {
204
+ "collapsed": false,
205
+ "gather": {
206
+ "logged": 1706408853264
207
+ }
208
+ }
209
+ },
210
+ {
211
+ "cell_type": "markdown",
212
+ "source": [
213
+ "### Tokenisation and encoding"
214
+ ],
215
+ "metadata": {
216
+ "collapsed": false
217
+ }
218
+ },
219
+ {
220
+ "cell_type": "code",
221
+ "source": [
222
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
223
+ ],
224
+ "outputs": [],
225
+ "execution_count": 8,
226
+ "metadata": {
227
+ "datalore": {
228
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
229
+ "type": "CODE",
230
+ "hide_input_from_viewers": true,
231
+ "hide_output_from_viewers": true
232
+ },
233
+ "gather": {
234
+ "logged": 1706408853475
235
+ }
236
+ }
237
+ },
238
+ {
239
+ "cell_type": "code",
240
+ "source": [
241
+ "def tokenize_and_encode(examples):\n",
242
+ " return tokenizer(examples[\"text\"], truncation=True)"
243
+ ],
244
+ "outputs": [],
245
+ "execution_count": 9,
246
+ "metadata": {
247
+ "datalore": {
248
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
249
+ "type": "CODE",
250
+ "hide_input_from_viewers": true,
251
+ "hide_output_from_viewers": true
252
+ },
253
+ "gather": {
254
+ "logged": 1706408853684
255
+ }
256
+ }
257
+ },
258
+ {
259
+ "cell_type": "code",
260
+ "source": [
261
+ "cols = dataset[\"train\"].column_names\n",
262
+ "cols.remove(\"labels\")\n",
263
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
264
+ ],
265
+ "outputs": [
266
+ {
267
+ "output_type": "stream",
268
+ "name": "stderr",
269
+ "text": "Map: 100%|██████████| 15786/15786 [00:01<00:00, 10990.82 examples/s]\n"
270
+ }
271
+ ],
272
+ "execution_count": 10,
273
+ "metadata": {
274
+ "datalore": {
275
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
276
+ "type": "CODE",
277
+ "hide_input_from_viewers": true,
278
+ "hide_output_from_viewers": true
279
+ },
280
+ "gather": {
281
+ "logged": 1706408854738
282
+ }
283
+ }
284
+ },
285
+ {
286
+ "cell_type": "markdown",
287
+ "source": [
288
+ "### Training"
289
+ ],
290
+ "metadata": {
291
+ "collapsed": false
292
+ }
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "source": [
297
+ "class MultiLabelTrainer(Trainer):\n",
298
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
299
+ " labels = inputs.pop(\"labels\")\n",
300
+ " outputs = model(**inputs)\n",
301
+ " logits = outputs.logits\n",
302
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
303
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
304
+ " labels.float().view(-1, self.model.config.num_labels))\n",
305
+ " return (loss, outputs) if return_outputs else loss"
306
+ ],
307
+ "outputs": [],
308
+ "execution_count": 11,
309
+ "metadata": {
310
+ "datalore": {
311
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
312
+ "type": "CODE",
313
+ "hide_input_from_viewers": true,
314
+ "hide_output_from_viewers": true
315
+ },
316
+ "gather": {
317
+ "logged": 1706408854925
318
+ }
319
+ }
320
+ },
321
+ {
322
+ "cell_type": "code",
323
+ "source": [
324
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
325
+ ],
326
+ "outputs": [
327
+ {
328
+ "output_type": "stream",
329
+ "name": "stderr",
330
+ "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
331
+ }
332
+ ],
333
+ "execution_count": 12,
334
+ "metadata": {
335
+ "datalore": {
336
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
337
+ "type": "CODE",
338
+ "hide_input_from_viewers": true,
339
+ "hide_output_from_viewers": true
340
+ },
341
+ "gather": {
342
+ "logged": 1706408857008
343
+ }
344
+ }
345
+ },
346
+ {
347
+ "cell_type": "code",
348
+ "source": [
349
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
350
+ " y_pred = torch.from_numpy(y_pred)\n",
351
+ " y_true = torch.from_numpy(y_true)\n",
352
+ "\n",
353
+ " if sigmoid:\n",
354
+ " y_pred = y_pred.sigmoid()\n",
355
+ "\n",
356
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
357
+ ],
358
+ "outputs": [],
359
+ "execution_count": 13,
360
+ "metadata": {
361
+ "datalore": {
362
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
363
+ "type": "CODE",
364
+ "hide_input_from_viewers": true,
365
+ "hide_output_from_viewers": true
366
+ },
367
+ "gather": {
368
+ "logged": 1706408857297
369
+ }
370
+ }
371
+ },
372
+ {
373
+ "cell_type": "code",
374
+ "source": [
375
+ "def compute_metrics(eval_pred):\n",
376
+ " predictions, labels = eval_pred\n",
377
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
378
+ ],
379
+ "outputs": [],
380
+ "execution_count": 14,
381
+ "metadata": {
382
+ "datalore": {
383
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
384
+ "type": "CODE",
385
+ "hide_input_from_viewers": true,
386
+ "hide_output_from_viewers": true
387
+ },
388
+ "gather": {
389
+ "logged": 1706408857499
390
+ }
391
+ }
392
+ },
393
+ {
394
+ "cell_type": "code",
395
+ "source": [
396
+ "args = TrainingArguments(\n",
397
+ " output_dir=\"vaers\",\n",
398
+ " evaluation_strategy=\"epoch\",\n",
399
+ " learning_rate=2e-5,\n",
400
+ " per_device_train_batch_size=BATCH_SIZE,\n",
401
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
402
+ " num_train_epochs=EPOCHS,\n",
403
+ " weight_decay=.01,\n",
404
+ " report_to=[\"wandb\"]\n",
405
+ ")"
406
+ ],
407
+ "outputs": [],
408
+ "execution_count": 15,
409
+ "metadata": {
410
+ "datalore": {
411
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
412
+ "type": "CODE",
413
+ "hide_input_from_viewers": true,
414
+ "hide_output_from_viewers": true
415
+ },
416
+ "gather": {
417
+ "logged": 1706408857680
418
+ }
419
+ }
420
+ },
421
+ {
422
+ "cell_type": "code",
423
+ "source": [
424
+ "multi_label_trainer = MultiLabelTrainer(\n",
425
+ " model, \n",
426
+ " args, \n",
427
+ " train_dataset=ds_enc[\"train\"], \n",
428
+ " eval_dataset=ds_enc[\"test\"], \n",
429
+ " compute_metrics=compute_metrics, \n",
430
+ " tokenizer=tokenizer\n",
431
+ ")"
432
+ ],
433
+ "outputs": [
434
+ {
435
+ "output_type": "stream",
436
+ "name": "stderr",
437
+ "text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
438
+ }
439
+ ],
440
+ "execution_count": 18,
441
+ "metadata": {
442
+ "datalore": {
443
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
444
+ "type": "CODE",
445
+ "hide_input_from_viewers": true,
446
+ "hide_output_from_viewers": true
447
+ },
448
+ "gather": {
449
+ "logged": 1706408895305
450
+ }
451
+ }
452
+ },
453
+ {
454
+ "cell_type": "code",
455
+ "source": [
456
+ "multi_label_trainer.evaluate()"
457
+ ],
458
+ "outputs": [
459
+ {
460
+ "output_type": "display_data",
461
+ "data": {
462
+ "text/plain": "<IPython.core.display.HTML object>",
463
+ "text/html": "\n <div>\n \n <progress value='1974' max='987' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [987/987 21:41]\n </div>\n "
464
+ },
465
+ "metadata": {}
466
+ },
467
+ {
468
+ "output_type": "stream",
469
+ "name": "stderr",
470
+ "text": "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
471
+ },
472
+ {
473
+ "output_type": "display_data",
474
+ "data": {
475
+ "text/plain": "<IPython.core.display.HTML object>",
476
+ "text/html": "Tracking run with wandb version 0.16.2"
477
+ },
478
+ "metadata": {}
479
+ },
480
+ {
481
+ "output_type": "display_data",
482
+ "data": {
483
+ "text/plain": "<IPython.core.display.HTML object>",
484
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_022947-hh1sxw9i</code>"
485
+ },
486
+ "metadata": {}
487
+ },
488
+ {
489
+ "output_type": "display_data",
490
+ "data": {
491
+ "text/plain": "<IPython.core.display.HTML object>",
492
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i' target=\"_blank\">icy-firebrand-1</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
493
+ },
494
+ "metadata": {}
495
+ },
496
+ {
497
+ "output_type": "display_data",
498
+ "data": {
499
+ "text/plain": "<IPython.core.display.HTML object>",
500
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
501
+ },
502
+ "metadata": {}
503
+ },
504
+ {
505
+ "output_type": "display_data",
506
+ "data": {
507
+ "text/plain": "<IPython.core.display.HTML object>",
508
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i</a>"
509
+ },
510
+ "metadata": {}
511
+ },
512
+ {
513
+ "output_type": "execute_result",
514
+ "execution_count": 19,
515
+ "data": {
516
+ "text/plain": "{'eval_loss': 0.7153111100196838,\n 'eval_accuracy_thresh': 0.2938227355480194,\n 'eval_runtime': 82.3613,\n 'eval_samples_per_second': 191.668,\n 'eval_steps_per_second': 11.984}"
517
+ },
518
+ "metadata": {}
519
+ }
520
+ ],
521
+ "execution_count": 19,
522
+ "metadata": {
523
+ "datalore": {
524
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
525
+ "type": "CODE",
526
+ "hide_input_from_viewers": true,
527
+ "hide_output_from_viewers": true
528
+ },
529
+ "gather": {
530
+ "logged": 1706408991752
531
+ }
532
+ }
533
+ },
534
+ {
535
+ "cell_type": "code",
536
+ "source": [
537
+ "multi_label_trainer.train()"
538
+ ],
539
+ "outputs": [
540
+ {
541
+ "output_type": "display_data",
542
+ "data": {
543
+ "text/plain": "<IPython.core.display.HTML object>",
544
+ "text/html": "\n <div>\n \n <progress value='4605' max='4605' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [4605/4605 20:25, Epoch 1/1]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n <th>Accuracy Thresh</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>1</td>\n <td>0.086700</td>\n <td>0.093388</td>\n <td>0.962897</td>\n </tr>\n </tbody>\n</table><p>"
545
+ },
546
+ "metadata": {}
547
+ },
548
+ {
549
+ "output_type": "stream",
550
+ "name": "stderr",
551
+ "text": "Checkpoint destination directory vaers/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.9s\nCheckpoint destination directory vaers/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 12.5s\nCheckpoint destination directory vaers/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 21.9s\nCheckpoint destination directory vaers/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 13.8s\nCheckpoint destination directory vaers/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 15.7s\nCheckpoint destination directory vaers/checkpoint-3000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 21.7s\nCheckpoint destination directory vaers/checkpoint-3500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 10.6s\nCheckpoint destination directory vaers/checkpoint-4000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4000)... Done. 15.0s\nCheckpoint destination directory vaers/checkpoint-4500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4500)... Done. 16.7s\n"
552
+ },
553
+ {
554
+ "output_type": "execute_result",
555
+ "execution_count": 21,
556
+ "data": {
557
+ "text/plain": "TrainOutput(global_step=4605, training_loss=0.09062977189220382, metrics={'train_runtime': 1223.2444, 'train_samples_per_second': 60.223, 'train_steps_per_second': 3.765, 'total_flos': 9346797199425174.0, 'train_loss': 0.09062977189220382, 'epoch': 1.0})"
558
+ },
559
+ "metadata": {}
560
+ }
561
+ ],
562
+ "execution_count": 21,
563
+ "metadata": {
564
+ "datalore": {
565
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
566
+ "type": "CODE",
567
+ "hide_input_from_viewers": true,
568
+ "hide_output_from_viewers": true
569
+ },
570
+ "gather": {
571
+ "logged": 1706411445752
572
+ }
573
+ }
574
+ },
575
+ {
576
+ "cell_type": "markdown",
577
+ "source": [
578
+ "### Evaluation"
579
+ ],
580
+ "metadata": {
581
+ "collapsed": false
582
+ }
583
+ },
584
+ {
585
+ "cell_type": "markdown",
586
+ "source": [
587
+ "We instantiate a classifier `pipeline` and push it to CUDA."
588
+ ],
589
+ "metadata": {
590
+ "collapsed": false
591
+ }
592
+ },
593
+ {
594
+ "cell_type": "code",
595
+ "source": [
596
+ "classifier = pipeline(\"text-classification\", \n",
597
+ " model, \n",
598
+ " tokenizer=tokenizer, \n",
599
+ " device=\"cuda:0\")"
600
+ ],
601
+ "outputs": [],
602
+ "execution_count": 24,
603
+ "metadata": {
604
+ "datalore": {
605
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
606
+ "type": "CODE",
607
+ "hide_input_from_viewers": true,
608
+ "hide_output_from_viewers": true
609
+ },
610
+ "gather": {
611
+ "logged": 1706411459928
612
+ }
613
+ }
614
+ },
615
+ {
616
+ "cell_type": "markdown",
617
+ "source": [
618
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
619
+ ],
620
+ "metadata": {
621
+ "collapsed": false
622
+ }
623
+ },
624
+ {
625
+ "cell_type": "code",
626
+ "source": [
627
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
628
+ " max_length=255, \n",
629
+ " pad_to_max_length=True, \n",
630
+ " return_token_type_ids=True, \n",
631
+ " truncation=True)"
632
+ ],
633
+ "outputs": [
634
+ {
635
+ "output_type": "error",
636
+ "ename": "KeyError",
637
+ "evalue": "'validate'",
638
+ "traceback": [
639
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
640
+ "\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
641
+ "Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test_encodings \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_encode_plus(\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvalidate\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m], \n\u001b[1;32m 2\u001b[0m max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m255\u001b[39m, \n\u001b[1;32m 3\u001b[0m pad_to_max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \n\u001b[1;32m 4\u001b[0m return_token_type_ids\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \n\u001b[1;32m 5\u001b[0m truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
642
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/datasets/dataset_dict.py:74\u001b[0m, in \u001b[0;36mDatasetDict.__getitem__\u001b[0;34m(self, k)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, k) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dataset:\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(k, (\u001b[38;5;28mstr\u001b[39m, NamedSplit)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m available_suggested_splits \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 77\u001b[0m split \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m (Split\u001b[38;5;241m.\u001b[39mTRAIN, Split\u001b[38;5;241m.\u001b[39mTEST, Split\u001b[38;5;241m.\u001b[39mVALIDATION) \u001b[38;5;28;01mif\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 78\u001b[0m ]\n",
643
+ "\u001b[0;31mKeyError\u001b[0m: 'validate'"
644
+ ]
645
+ }
646
+ ],
647
+ "execution_count": 25,
648
+ "metadata": {
649
+ "datalore": {
650
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
651
+ "type": "CODE",
652
+ "hide_input_from_viewers": true,
653
+ "hide_output_from_viewers": true
654
+ },
655
+ "gather": {
656
+ "logged": 1706411465538
657
+ }
658
+ }
659
+ },
660
+ {
661
+ "cell_type": "markdown",
662
+ "source": [
663
+ "Once we've made the data loadable by putting it into a `DataLoader`, we "
664
+ ],
665
+ "metadata": {
666
+ "collapsed": false
667
+ }
668
+ },
669
+ {
670
+ "cell_type": "code",
671
+ "source": [
672
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
673
+ " torch.tensor(test_encodings['attention_mask']), \n",
674
+ " torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n",
675
+ " torch.tensor(test_encodings['token_type_ids']))\n",
676
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
677
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
678
+ " batch_size=BATCH_SIZE)"
679
+ ],
680
+ "outputs": [],
681
+ "execution_count": null,
682
+ "metadata": {
683
+ "datalore": {
684
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
685
+ "type": "CODE",
686
+ "hide_input_from_viewers": true,
687
+ "hide_output_from_viewers": true
688
+ },
689
+ "gather": {
690
+ "logged": 1706411446707
691
+ }
692
+ }
693
+ },
694
+ {
695
+ "cell_type": "code",
696
+ "source": [
697
+ "model.eval()\n",
698
+ "\n",
699
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
700
+ "\n",
701
+ "for i, batch in enumerate(test_dataloader):\n",
702
+ " batch = tuple(t.to(device) for t in batch)\n",
703
+ " # Unpack the inputs from our dataloader\n",
704
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
705
+ " \n",
706
+ " with torch.no_grad():\n",
707
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
708
+ " b_logit_pred = outs[0]\n",
709
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
710
+ "\n",
711
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
712
+ " pred_label = pred_label.to('cpu').numpy()\n",
713
+ " b_labels = b_labels.to('cpu').numpy()\n",
714
+ "\n",
715
+ " tokenized_texts.append(b_input_ids)\n",
716
+ " logit_preds.append(b_logit_pred)\n",
717
+ " true_labels.append(b_labels)\n",
718
+ " pred_labels.append(pred_label)\n",
719
+ "\n",
720
+ "# Flatten outputs\n",
721
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
722
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
723
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
724
+ "\n",
725
+ "# Converting flattened binary values to boolean values\n",
726
+ "true_bools = [tl == 1 for tl in true_labels]\n",
727
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
728
+ ],
729
+ "outputs": [],
730
+ "execution_count": null,
731
+ "metadata": {
732
+ "datalore": {
733
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
734
+ "type": "CODE",
735
+ "hide_input_from_viewers": true,
736
+ "hide_output_from_viewers": true
737
+ },
738
+ "gather": {
739
+ "logged": 1706411446723
740
+ }
741
+ }
742
+ },
743
+ {
744
+ "cell_type": "markdown",
745
+ "source": [
746
+ "We create a classification report:"
747
+ ],
748
+ "metadata": {
749
+ "collapsed": false
750
+ }
751
+ },
752
+ {
753
+ "cell_type": "code",
754
+ "source": [
755
+ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
756
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
757
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
758
+ "print(clf_report)"
759
+ ],
760
+ "outputs": [],
761
+ "execution_count": null,
762
+ "metadata": {
763
+ "datalore": {
764
+ "node_id": "eBprrgF086mznPbPVBpOLS",
765
+ "type": "CODE",
766
+ "hide_input_from_viewers": true,
767
+ "hide_output_from_viewers": true
768
+ },
769
+ "gather": {
770
+ "logged": 1706411446746
771
+ }
772
+ }
773
+ },
774
+ {
775
+ "cell_type": "markdown",
776
+ "source": [
777
+ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
778
+ ],
779
+ "metadata": {
780
+ "collapsed": false
781
+ }
782
+ },
783
+ {
784
+ "cell_type": "code",
785
+ "source": [
786
+ "# Creating a map of class names from class numbers\n",
787
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
788
+ ],
789
+ "outputs": [],
790
+ "execution_count": null,
791
+ "metadata": {
792
+ "datalore": {
793
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
794
+ "type": "CODE",
795
+ "hide_input_from_viewers": true,
796
+ "hide_output_from_viewers": true
797
+ },
798
+ "gather": {
799
+ "logged": 1706411446758
800
+ }
801
+ }
802
+ },
803
+ {
804
+ "cell_type": "code",
805
+ "source": [
806
+ "true_label_idxs, pred_label_idxs = [], []\n",
807
+ "\n",
808
+ "for vals in true_bools:\n",
809
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
810
+ "for vals in pred_bools:\n",
811
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
812
+ ],
813
+ "outputs": [],
814
+ "execution_count": null,
815
+ "metadata": {
816
+ "datalore": {
817
+ "node_id": "jH0S35dDteUch01sa6me6e",
818
+ "type": "CODE",
819
+ "hide_input_from_viewers": true,
820
+ "hide_output_from_viewers": true
821
+ },
822
+ "gather": {
823
+ "logged": 1706411446771
824
+ }
825
+ }
826
+ },
827
+ {
828
+ "cell_type": "code",
829
+ "source": [
830
+ "true_label_texts, pred_label_texts = [], []\n",
831
+ "\n",
832
+ "for vals in true_label_idxs:\n",
833
+ " if vals:\n",
834
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
835
+ " else:\n",
836
+ " true_label_texts.append(vals)\n",
837
+ "\n",
838
+ "for vals in pred_label_idxs:\n",
839
+ " if vals:\n",
840
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
841
+ " else:\n",
842
+ " pred_label_texts.append(vals)"
843
+ ],
844
+ "outputs": [],
845
+ "execution_count": null,
846
+ "metadata": {
847
+ "datalore": {
848
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
849
+ "type": "CODE",
850
+ "hide_input_from_viewers": true,
851
+ "hide_output_from_viewers": true
852
+ },
853
+ "gather": {
854
+ "logged": 1706411446785
855
+ }
856
+ }
857
+ },
858
+ {
859
+ "cell_type": "code",
860
+ "source": [
861
+ "symptom_texts = [tokenizer.decode(text,\n",
862
+ " skip_special_tokens=True,\n",
863
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
864
+ ],
865
+ "outputs": [],
866
+ "execution_count": null,
867
+ "metadata": {
868
+ "datalore": {
869
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
870
+ "type": "CODE",
871
+ "hide_input_from_viewers": true,
872
+ "hide_output_from_viewers": true
873
+ },
874
+ "gather": {
875
+ "logged": 1706411446805
876
+ }
877
+ }
878
+ },
879
+ {
880
+ "cell_type": "code",
881
+ "source": [
882
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
883
+ " 'true_labels': true_label_texts, \n",
884
+ " 'pred_labels':pred_label_texts})\n",
885
+ "comparisons_df.to_csv('comparisons.csv')\n",
886
+ "comparisons_df"
887
+ ],
888
+ "outputs": [],
889
+ "execution_count": null,
890
+ "metadata": {
891
+ "datalore": {
892
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
893
+ "type": "CODE",
894
+ "hide_input_from_viewers": true,
895
+ "hide_output_from_viewers": true
896
+ },
897
+ "gather": {
898
+ "logged": 1706411446818
899
+ }
900
+ }
901
+ },
902
+ {
903
+ "cell_type": "markdown",
904
+ "source": [
905
+ "### Shapley analysis"
906
+ ],
907
+ "metadata": {
908
+ "collapsed": false
909
+ }
910
+ },
911
+ {
912
+ "cell_type": "code",
913
+ "source": [
914
+ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
915
+ ],
916
+ "outputs": [],
917
+ "execution_count": null,
918
+ "metadata": {
919
+ "datalore": {
920
+ "node_id": "OpdZcoenX2HwzLdai7K5UA",
921
+ "type": "CODE",
922
+ "hide_input_from_viewers": true,
923
+ "hide_output_from_viewers": true
924
+ },
925
+ "gather": {
926
+ "logged": 1706411446829
927
+ }
928
+ }
929
+ },
930
+ {
931
+ "cell_type": "code",
932
+ "source": [
933
+ "shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])"
934
+ ],
935
+ "outputs": [],
936
+ "execution_count": null,
937
+ "metadata": {
938
+ "datalore": {
939
+ "node_id": "FvbCMfIDlcf16YSvb8wNQv",
940
+ "type": "CODE",
941
+ "hide_input_from_viewers": true,
942
+ "hide_output_from_viewers": true
943
+ },
944
+ "gather": {
945
+ "logged": 1706411446839
946
+ }
947
+ }
948
+ },
949
+ {
950
+ "cell_type": "code",
951
+ "source": [
952
+ "shap.plots.text(shap_values)"
953
+ ],
954
+ "outputs": [],
955
+ "execution_count": null,
956
+ "metadata": {
957
+ "datalore": {
958
+ "node_id": "TSxvakWLPCpjVMWi9ZdEbd",
959
+ "type": "CODE",
960
+ "hide_input_from_viewers": true,
961
+ "hide_output_from_viewers": true
962
+ },
963
+ "gather": {
964
+ "logged": 1706411446848
965
+ }
966
+ }
967
+ },
968
+ {
969
+ "cell_type": "code",
970
+ "source": [],
971
+ "outputs": [],
972
+ "execution_count": null,
973
+ "metadata": {
974
+ "jupyter": {
975
+ "source_hidden": false,
976
+ "outputs_hidden": false
977
+ },
978
+ "nteract": {
979
+ "transient": {
980
+ "deleting": false
981
+ }
982
+ }
983
+ }
984
+ }
985
+ ],
986
+ "metadata": {
987
+ "kernelspec": {
988
+ "name": "python3",
989
+ "language": "python",
990
+ "display_name": "Python 3 (ipykernel)"
991
+ },
992
+ "datalore": {
993
+ "computation_mode": "JUPYTER",
994
+ "package_manager": "pip",
995
+ "base_environment": "default",
996
+ "packages": [
997
+ {
998
+ "name": "datasets",
999
+ "version": "2.16.1",
1000
+ "source": "PIP"
1001
+ },
1002
+ {
1003
+ "name": "torch",
1004
+ "version": "2.1.2",
1005
+ "source": "PIP"
1006
+ },
1007
+ {
1008
+ "name": "accelerate",
1009
+ "version": "0.26.1",
1010
+ "source": "PIP"
1011
+ }
1012
+ ],
1013
+ "report_row_ids": [
1014
+ "un8W7ez7ZwoGb5Co6nydEV",
1015
+ "40nN9Hvgi1clHNV5RAemI5",
1016
+ "TgRD90H5NSPpKS41OeXI1w",
1017
+ "ZOm5BfUs3h1EGLaUkBGeEB",
1018
+ "kOP0CZWNSk6vqE3wkPp7Vc",
1019
+ "W4PWcOu2O2pRaZyoE2W80h",
1020
+ "RolbOnQLIftk0vy9mIcz5M",
1021
+ "8OPhUgbaNJmOdiq5D3a6vK",
1022
+ "5Qrt3jSvSrpK6Ne1hS6shL",
1023
+ "hTq7nFUrovN5Ao4u6dIYWZ",
1024
+ "I8WNZLpJ1DVP2wiCW7YBIB",
1025
+ "SawhU3I9BewSE1XBPstpNJ",
1026
+ "80EtLEl2FIE4FqbWnUD3nT"
1027
+ ],
1028
+ "version": 3
1029
+ },
1030
+ "microsoft": {
1031
+ "ms_spell_check": {
1032
+ "ms_spell_check_language": "en"
1033
+ }
1034
+ },
1035
+ "language_info": {
1036
+ "name": "python",
1037
+ "version": "3.8.5",
1038
+ "mimetype": "text/x-python",
1039
+ "codemirror_mode": {
1040
+ "name": "ipython",
1041
+ "version": 3
1042
+ },
1043
+ "pygments_lexer": "ipython3",
1044
+ "nbconvert_exporter": "python",
1045
+ "file_extension": ".py"
1046
+ },
1047
+ "nteract": {
1048
+ "version": "nteract-front-end@1.0.0"
1049
+ }
1050
+ },
1051
+ "nbformat": 4,
1052
+ "nbformat_minor": 4
1053
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-4-13-53Z.ipynb ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-14-26-30Z.ipynb ADDED
@@ -0,0 +1,739 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {}
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "source": [
15
+ "%pip install accelerate -U"
16
+ ],
17
+ "outputs": [
18
+ {
19
+ "output_type": "stream",
20
+ "name": "stdout",
21
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nNote: you may need to restart the kernel to use updated packages.\n"
22
+ }
23
+ ],
24
+ "execution_count": 1,
25
+ "metadata": {
26
+ "gather": {
27
+ "logged": 1706475754655
28
+ },
29
+ "nteract": {
30
+ "transient": {
31
+ "deleting": false
32
+ }
33
+ },
34
+ "tags": []
35
+ }
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "source": [
40
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
41
+ ],
42
+ "outputs": [
43
+ {
44
+ "output_type": "stream",
45
+ "name": "stdout",
46
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
47
+ }
48
+ ],
49
+ "execution_count": 2,
50
+ "metadata": {
51
+ "nteract": {
52
+ "transient": {
53
+ "deleting": false
54
+ }
55
+ }
56
+ }
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "import pandas as pd\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "import os\n",
65
+ "from typing import List, Union\n",
66
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
67
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
68
+ "import shap\n",
69
+ "import wandb\n",
70
+ "import evaluate\n",
71
+ "from codecarbon import EmissionsTracker\n",
72
+ "import logging\n",
73
+ "\n",
74
+ "wandb.finish()\n",
75
+ "\n",
76
+ "logging.getLogger('codecarbon').propagate = False\n",
77
+ "\n",
78
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
79
+ "tracker = EmissionsTracker()\n",
80
+ "\n",
81
+ "%load_ext watermark"
82
+ ],
83
+ "outputs": [
84
+ {
85
+ "output_type": "stream",
86
+ "name": "stderr",
87
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 04:43:58.191236: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 04:43:59.182154: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 04:43:59.182291: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 04:43:59.182304: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n[codecarbon INFO @ 04:44:02] [setup] RAM Tracking...\n[codecarbon INFO @ 04:44:02] [setup] GPU Tracking...\n[codecarbon INFO @ 04:44:02] Tracking Nvidia GPU via pynvml\n[codecarbon INFO @ 04:44:02] [setup] CPU Tracking...\n[codecarbon WARNING @ 04:44:02] No CPU tracking mode found. Falling back on CPU constant mode.\n[codecarbon WARNING @ 04:44:03] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n[codecarbon INFO @ 04:44:03] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 04:44:03] >>> Tracker's metadata:\n[codecarbon INFO @ 04:44:03] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n[codecarbon INFO @ 04:44:03] Python version: 3.8.5\n[codecarbon INFO @ 04:44:03] CodeCarbon version: 2.3.3\n[codecarbon INFO @ 04:44:03] Available RAM : 440.883 GB\n[codecarbon INFO @ 04:44:03] CPU count: 24\n[codecarbon INFO @ 04:44:03] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 04:44:03] GPU count: 4\n[codecarbon INFO @ 04:44:03] GPU model: 4 x Tesla V100-PCIE-16GB\n[codecarbon WARNING @ 04:44:03] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
88
+ }
89
+ ],
90
+ "execution_count": 3,
91
+ "metadata": {
92
+ "datalore": {
93
+ "hide_input_from_viewers": false,
94
+ "hide_output_from_viewers": false,
95
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
96
+ "report_properties": {
97
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
98
+ },
99
+ "type": "CODE"
100
+ },
101
+ "gather": {
102
+ "logged": 1706503443742
103
+ },
104
+ "tags": []
105
+ }
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "source": [
110
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
111
+ "\n",
112
+ "SEED: int = 42\n",
113
+ "\n",
114
+ "BATCH_SIZE: int = 32\n",
115
+ "EPOCHS: int = 5\n",
116
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
117
+ "\n",
118
+ "# WandB configuration\n",
119
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
120
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
121
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
122
+ ],
123
+ "outputs": [],
124
+ "execution_count": 4,
125
+ "metadata": {
126
+ "collapsed": false,
127
+ "gather": {
128
+ "logged": 1706503443899
129
+ },
130
+ "jupyter": {
131
+ "outputs_hidden": false
132
+ }
133
+ }
134
+ },
135
+ {
136
+ "cell_type": "code",
137
+ "source": [
138
+ "%watermark --iversion"
139
+ ],
140
+ "outputs": [
141
+ {
142
+ "output_type": "stream",
143
+ "name": "stdout",
144
+ "text": "shap : 0.44.1\nnumpy : 1.23.5\npandas : 2.0.2\nlogging : 0.5.1.2\ntorch : 1.12.0\nevaluate: 0.4.1\nwandb : 0.16.2\nre : 2.2.1\n\n"
145
+ }
146
+ ],
147
+ "execution_count": 5,
148
+ "metadata": {
149
+ "collapsed": false,
150
+ "jupyter": {
151
+ "outputs_hidden": false
152
+ }
153
+ }
154
+ },
155
+ {
156
+ "cell_type": "code",
157
+ "source": [
158
+ "!nvidia-smi"
159
+ ],
160
+ "outputs": [
161
+ {
162
+ "output_type": "stream",
163
+ "name": "stdout",
164
+ "text": "Mon Jan 29 04:44:03 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
165
+ }
166
+ ],
167
+ "execution_count": 6,
168
+ "metadata": {
169
+ "datalore": {
170
+ "hide_input_from_viewers": true,
171
+ "hide_output_from_viewers": true,
172
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
173
+ "type": "CODE"
174
+ }
175
+ }
176
+ },
177
+ {
178
+ "cell_type": "markdown",
179
+ "source": [
180
+ "## Loading the data set"
181
+ ],
182
+ "metadata": {
183
+ "datalore": {
184
+ "hide_input_from_viewers": false,
185
+ "hide_output_from_viewers": false,
186
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
187
+ "report_properties": {
188
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
189
+ },
190
+ "type": "MD"
191
+ }
192
+ }
193
+ },
194
+ {
195
+ "cell_type": "code",
196
+ "source": [
197
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
198
+ ],
199
+ "outputs": [],
200
+ "execution_count": 7,
201
+ "metadata": {
202
+ "collapsed": false,
203
+ "gather": {
204
+ "logged": 1706503446033
205
+ },
206
+ "jupyter": {
207
+ "outputs_hidden": false
208
+ }
209
+ }
210
+ },
211
+ {
212
+ "cell_type": "code",
213
+ "source": [
214
+ "dataset"
215
+ ],
216
+ "outputs": [
217
+ {
218
+ "output_type": "execute_result",
219
+ "execution_count": 8,
220
+ "data": {
221
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
222
+ },
223
+ "metadata": {}
224
+ }
225
+ ],
226
+ "execution_count": 8,
227
+ "metadata": {
228
+ "collapsed": false,
229
+ "gather": {
230
+ "logged": 1706503446252
231
+ },
232
+ "jupyter": {
233
+ "outputs_hidden": false,
234
+ "source_hidden": false
235
+ },
236
+ "nteract": {
237
+ "transient": {
238
+ "deleting": false
239
+ }
240
+ }
241
+ }
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "source": [
246
+ "SUBSAMPLING = 1.0\n",
247
+ "\n",
248
+ "if SUBSAMPLING < 1:\n",
249
+ " _ = DatasetDict()\n",
250
+ " for each in dataset.keys():\n",
251
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
252
+ "\n",
253
+ " dataset = _"
254
+ ],
255
+ "outputs": [],
256
+ "execution_count": 9,
257
+ "metadata": {
258
+ "gather": {
259
+ "logged": 1706503446498
260
+ }
261
+ }
262
+ },
263
+ {
264
+ "cell_type": "markdown",
265
+ "source": [
266
+ "## Tokenisation and encoding"
267
+ ],
268
+ "metadata": {}
269
+ },
270
+ {
271
+ "cell_type": "code",
272
+ "source": [
273
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
274
+ " return ds_enc"
275
+ ],
276
+ "outputs": [],
277
+ "execution_count": 10,
278
+ "metadata": {
279
+ "gather": {
280
+ "logged": 1706503446633
281
+ }
282
+ }
283
+ },
284
+ {
285
+ "cell_type": "markdown",
286
+ "source": [
287
+ "## Evaluation metrics"
288
+ ],
289
+ "metadata": {}
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "source": [
294
+ "accuracy = evaluate.load(\"accuracy\")\n",
295
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
296
+ "f1 = evaluate.load(\"f1\")"
297
+ ],
298
+ "outputs": [],
299
+ "execution_count": 11,
300
+ "metadata": {
301
+ "gather": {
302
+ "logged": 1706503446863
303
+ }
304
+ }
305
+ },
306
+ {
307
+ "cell_type": "code",
308
+ "source": [
309
+ "def compute_metrics(eval_pred):\n",
310
+ " predictions, labels = eval_pred\n",
311
+ " predictions = np.argmax(predictions, axis=1)\n",
312
+ " return {\n",
313
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
314
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
315
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
316
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
317
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
318
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
319
+ " }"
320
+ ],
321
+ "outputs": [],
322
+ "execution_count": 12,
323
+ "metadata": {
324
+ "gather": {
325
+ "logged": 1706503447004
326
+ }
327
+ }
328
+ },
329
+ {
330
+ "cell_type": "markdown",
331
+ "source": [
332
+ "## Training"
333
+ ],
334
+ "metadata": {}
335
+ },
336
+ {
337
+ "cell_type": "markdown",
338
+ "source": [
339
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
340
+ ],
341
+ "metadata": {}
342
+ },
343
+ {
344
+ "cell_type": "code",
345
+ "source": [
346
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
347
+ ],
348
+ "outputs": [],
349
+ "execution_count": 13,
350
+ "metadata": {
351
+ "gather": {
352
+ "logged": 1706503447186
353
+ }
354
+ }
355
+ },
356
+ {
357
+ "cell_type": "code",
358
+ "source": [
359
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
360
+ "\n",
361
+ "cols = dataset[\"train\"].column_names\n",
362
+ "cols.remove(\"label\")\n",
363
+ "ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n"
364
+ ],
365
+ "outputs": [
366
+ {
367
+ "output_type": "stream",
368
+ "name": "stderr",
369
+ "text": "Map: 100%|██████████| 272238/272238 [01:45<00:00, 2592.04 examples/s]\n"
370
+ }
371
+ ],
372
+ "execution_count": 14,
373
+ "metadata": {
374
+ "gather": {
375
+ "logged": 1706503552083
376
+ }
377
+ }
378
+ },
379
+ {
380
+ "cell_type": "code",
381
+ "source": [
382
+ "\n",
383
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
384
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
385
+ " id2label=label_map, \n",
386
+ " label2id={v:k for k,v in label_map.items()})\n",
387
+ "\n",
388
+ "args = TrainingArguments(\n",
389
+ " output_dir=\"vaers\",\n",
390
+ " evaluation_strategy=\"epoch\",\n",
391
+ " save_strategy=\"epoch\",\n",
392
+ " learning_rate=2e-5,\n",
393
+ " per_device_train_batch_size=BATCH_SIZE,\n",
394
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
395
+ " num_train_epochs=EPOCHS,\n",
396
+ " weight_decay=.01,\n",
397
+ " logging_steps=1,\n",
398
+ " load_best_model_at_end=True,\n",
399
+ " run_name=f\"daedra-training\",\n",
400
+ " report_to=[\"wandb\"])\n",
401
+ "\n",
402
+ "trainer = Trainer(\n",
403
+ " model=model,\n",
404
+ " args=args,\n",
405
+ " train_dataset=ds_enc[\"train\"],\n",
406
+ " eval_dataset=ds_enc[\"test\"],\n",
407
+ " tokenizer=tokenizer,\n",
408
+ " compute_metrics=compute_metrics)"
409
+ ],
410
+ "outputs": [
411
+ {
412
+ "output_type": "stream",
413
+ "name": "stderr",
414
+ "text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
415
+ }
416
+ ],
417
+ "execution_count": 15,
418
+ "metadata": {
419
+ "gather": {
420
+ "logged": 1706503554669
421
+ }
422
+ }
423
+ },
424
+ {
425
+ "cell_type": "code",
426
+ "source": [
427
+ "if SUBSAMPLING != 1.0:\n",
428
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
429
+ "else:\n",
430
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
431
+ "\n",
432
+ "wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
433
+ "wandb_tag.append(f\"base:{model_ckpt}\")\n",
434
+ " \n",
435
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
436
+ ],
437
+ "outputs": [
438
+ {
439
+ "output_type": "stream",
440
+ "name": "stderr",
441
+ "text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
442
+ },
443
+ {
444
+ "output_type": "display_data",
445
+ "data": {
446
+ "text/plain": "<IPython.core.display.HTML object>",
447
+ "text/html": "Tracking run with wandb version 0.16.2"
448
+ },
449
+ "metadata": {}
450
+ },
451
+ {
452
+ "output_type": "display_data",
453
+ "data": {
454
+ "text/plain": "<IPython.core.display.HTML object>",
455
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_044555-kjhyoltp</code>"
456
+ },
457
+ "metadata": {}
458
+ },
459
+ {
460
+ "output_type": "display_data",
461
+ "data": {
462
+ "text/plain": "<IPython.core.display.HTML object>",
463
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
464
+ },
465
+ "metadata": {}
466
+ },
467
+ {
468
+ "output_type": "display_data",
469
+ "data": {
470
+ "text/plain": "<IPython.core.display.HTML object>",
471
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
472
+ },
473
+ "metadata": {}
474
+ },
475
+ {
476
+ "output_type": "display_data",
477
+ "data": {
478
+ "text/plain": "<IPython.core.display.HTML object>",
479
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp</a>"
480
+ },
481
+ "metadata": {}
482
+ },
483
+ {
484
+ "output_type": "display_data",
485
+ "data": {
486
+ "text/plain": "<IPython.core.display.HTML object>",
487
+ "text/html": "Finishing last run (ID:kjhyoltp) before initializing another..."
488
+ },
489
+ "metadata": {}
490
+ },
491
+ {
492
+ "output_type": "display_data",
493
+ "data": {
494
+ "text/plain": "<IPython.core.display.HTML object>",
495
+ "text/html": "W&B sync reduced upload amount by 26.5% "
496
+ },
497
+ "metadata": {}
498
+ },
499
+ {
500
+ "output_type": "display_data",
501
+ "data": {
502
+ "text/plain": "<IPython.core.display.HTML object>",
503
+ "text/html": " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v1' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v1</a><br/>Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
504
+ },
505
+ "metadata": {}
506
+ },
507
+ {
508
+ "output_type": "display_data",
509
+ "data": {
510
+ "text/plain": "<IPython.core.display.HTML object>",
511
+ "text/html": "Find logs at: <code>./wandb/run-20240129_044555-kjhyoltp/logs</code>"
512
+ },
513
+ "metadata": {}
514
+ },
515
+ {
516
+ "output_type": "display_data",
517
+ "data": {
518
+ "text/plain": "<IPython.core.display.HTML object>",
519
+ "text/html": "Successfully finished last run (ID:kjhyoltp). Initializing new run:<br/>"
520
+ },
521
+ "metadata": {}
522
+ },
523
+ {
524
+ "output_type": "display_data",
525
+ "data": {
526
+ "text/plain": "<IPython.core.display.HTML object>",
527
+ "text/html": "Tracking run with wandb version 0.16.2"
528
+ },
529
+ "metadata": {}
530
+ },
531
+ {
532
+ "output_type": "display_data",
533
+ "data": {
534
+ "text/plain": "<IPython.core.display.HTML object>",
535
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_044558-ed51hqn6</code>"
536
+ },
537
+ "metadata": {}
538
+ },
539
+ {
540
+ "output_type": "display_data",
541
+ "data": {
542
+ "text/plain": "<IPython.core.display.HTML object>",
543
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
544
+ },
545
+ "metadata": {}
546
+ },
547
+ {
548
+ "output_type": "display_data",
549
+ "data": {
550
+ "text/plain": "<IPython.core.display.HTML object>",
551
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
552
+ },
553
+ "metadata": {}
554
+ },
555
+ {
556
+ "output_type": "display_data",
557
+ "data": {
558
+ "text/plain": "<IPython.core.display.HTML object>",
559
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6</a>"
560
+ },
561
+ "metadata": {}
562
+ },
563
+ {
564
+ "output_type": "execute_result",
565
+ "execution_count": 16,
566
+ "data": {
567
+ "text/html": "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>",
568
+ "text/plain": "<wandb.sdk.wandb_run.Run at 0x7f19d09fdbe0>"
569
+ },
570
+ "metadata": {}
571
+ }
572
+ ],
573
+ "execution_count": 16,
574
+ "metadata": {
575
+ "gather": {
576
+ "logged": 1706503566090
577
+ }
578
+ }
579
+ },
580
+ {
581
+ "cell_type": "code",
582
+ "source": [
583
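+ "# Wrap training in the emissions tracker (a codecarbon EmissionsTracker, assumed to be instantiated earlier in the notebook; its logs appear in the output below)\n",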
+ "tracker.start()\n",
584
+ "trainer.train()\n",
585
+ "tracker.stop()\n"
586
+ ],
587
+ "outputs": [
588
+ {
589
+ "output_type": "stream",
590
+ "name": "stderr",
591
+ "text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
592
+ },
593
+ {
594
+ "output_type": "display_data",
595
+ "data": {
596
+ "text/plain": "<IPython.core.display.HTML object>",
597
+ "text/html": "\n <div>\n \n <progress value='183' max='49630' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 183/49630 01:56 < 8:51:06, 1.55 it/s, Epoch 0.02/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
598
+ },
599
+ "metadata": {}
600
+ },
601
+ {
602
+ "output_type": "stream",
603
+ "name": "stderr",
604
+ "text": "[codecarbon INFO @ 04:46:20] Energy consumed for RAM : 0.000690 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:46:20] Energy consumed for all GPUs : 0.001499 kWh. Total GPU Power : 359.1829830586385 W\n[codecarbon INFO @ 04:46:20] Energy consumed for all CPUs : 0.000177 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:46:20] 0.002366 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:46:35] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:46:35] Energy consumed for all GPUs : 0.004078 kWh. Total GPU Power : 619.6193403526773 W\n[codecarbon INFO @ 04:46:35] Energy consumed for all CPUs : 0.000355 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:46:35] 0.005811 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:46:50] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:46:50] Energy consumed for all GPUs : 0.006632 kWh. Total GPU Power : 613.6554096062922 W\n[codecarbon INFO @ 04:46:50] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:46:50] 0.009230 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:05] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:05] Energy consumed for all GPUs : 0.009249 kWh. Total GPU Power : 628.5574609453653 W\n[codecarbon INFO @ 04:47:05] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:05] 0.012712 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:20] Energy consumed for RAM : 0.003442 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:20] Energy consumed for all GPUs : 0.011850 kWh. Total GPU Power : 624.8454173521444 W\n[codecarbon INFO @ 04:47:20] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:20] 0.016178 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:35] Energy consumed for RAM : 0.004130 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:35] Energy consumed for all GPUs : 0.014490 kWh. Total GPU Power : 634.7378588005432 W\n[codecarbon INFO @ 04:47:35] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:35] 0.019683 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:50] Energy consumed for RAM : 0.004818 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:50] Energy consumed for all GPUs : 0.017140 kWh. Total GPU Power : 636.6500188212152 W\n[codecarbon INFO @ 04:47:50] Energy consumed for all CPUs : 0.001240 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:50] 0.023197 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:48:05] Energy consumed for RAM : 0.005506 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:48:05] Energy consumed for all GPUs : 0.019771 kWh. Total GPU Power : 631.881788173399 W\n[codecarbon INFO @ 04:48:05] Energy consumed for all CPUs : 0.001417 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:48:05] 0.026694 kWh of electricity used since the beginning.\n"
605
+ }
606
+ ],
607
+ "execution_count": 17,
608
+ "metadata": {
609
+ "gather": {
610
+ "logged": 1706486541798
611
+ }
612
+ }
613
+ },
614
+ {
615
+ "cell_type": "code",
616
+ "source": [
617
+ "wandb.finish()"
618
+ ],
619
+ "outputs": [],
620
+ "execution_count": null,
621
+ "metadata": {
622
+ "gather": {
623
+ "logged": 1706486541918
624
+ }
625
+ }
626
+ },
627
+ {
628
+ "cell_type": "code",
629
+ "source": [
630
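+ "# Name the Hub upload after the subsampling fraction so different training variants stay distinguishable\n",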
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
631
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
632
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
633
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
634
+ "\n",
635
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
636
+ " variant=variant,\n",
637
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
638
+ ],
639
+ "outputs": [],
640
+ "execution_count": null,
641
+ "metadata": {
642
+ "gather": {
643
+ "logged": 1706486541928
644
+ }
645
+ }
646
+ },
647
+ {
648
+ "cell_type": "code",
649
+ "source": [
650
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
651
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
652
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
653
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
654
+ "\n",
655
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
656
+ " variant=variant,\n",
657
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
658
+ ],
659
+ "outputs": [],
660
+ "execution_count": null,
661
+ "metadata": {}
662
+ }
663
+ ],
664
+ "metadata": {
665
+ "datalore": {
666
+ "base_environment": "default",
667
+ "computation_mode": "JUPYTER",
668
+ "package_manager": "pip",
669
+ "packages": [
670
+ {
671
+ "name": "datasets",
672
+ "source": "PIP",
673
+ "version": "2.16.1"
674
+ },
675
+ {
676
+ "name": "torch",
677
+ "source": "PIP",
678
+ "version": "2.1.2"
679
+ },
680
+ {
681
+ "name": "accelerate",
682
+ "source": "PIP",
683
+ "version": "0.26.1"
684
+ }
685
+ ],
686
+ "report_row_ids": [
687
+ "un8W7ez7ZwoGb5Co6nydEV",
688
+ "40nN9Hvgi1clHNV5RAemI5",
689
+ "TgRD90H5NSPpKS41OeXI1w",
690
+ "ZOm5BfUs3h1EGLaUkBGeEB",
691
+ "kOP0CZWNSk6vqE3wkPp7Vc",
692
+ "W4PWcOu2O2pRaZyoE2W80h",
693
+ "RolbOnQLIftk0vy9mIcz5M",
694
+ "8OPhUgbaNJmOdiq5D3a6vK",
695
+ "5Qrt3jSvSrpK6Ne1hS6shL",
696
+ "hTq7nFUrovN5Ao4u6dIYWZ",
697
+ "I8WNZLpJ1DVP2wiCW7YBIB",
698
+ "SawhU3I9BewSE1XBPstpNJ",
699
+ "80EtLEl2FIE4FqbWnUD3nT"
700
+ ],
701
+ "version": 3
702
+ },
703
+ "kernel_info": {
704
+ "name": "python38-azureml-pt-tf"
705
+ },
706
+ "kernelspec": {
707
+ "display_name": "azureml_py38_PT_TF",
708
+ "language": "python",
709
+ "name": "python3"
710
+ },
711
+ "language_info": {
712
+ "name": "python",
713
+ "version": "3.8.5",
714
+ "mimetype": "text/x-python",
715
+ "codemirror_mode": {
716
+ "name": "ipython",
717
+ "version": 3
718
+ },
719
+ "pygments_lexer": "ipython3",
720
+ "nbconvert_exporter": "python",
721
+ "file_extension": ".py"
722
+ },
723
+ "microsoft": {
724
+ "host": {
725
+ "AzureML": {
726
+ "notebookHasBeenCompleted": true
727
+ }
728
+ },
729
+ "ms_spell_check": {
730
+ "ms_spell_check_language": "en"
731
+ }
732
+ },
733
+ "nteract": {
734
+ "version": "nteract-front-end@1.0.0"
735
+ }
736
+ },
737
+ "nbformat": 4,
738
+ "nbformat_minor": 4
739
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-16-5-15Z.ipynb ADDED
@@ -0,0 +1,729 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "gather": {
17
+ "logged": 1706475754655
18
+ },
19
+ "nteract": {
20
+ "transient": {
21
+ "deleting": false
22
+ }
23
+ },
24
+ "tags": []
25
+ },
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
32
+ "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
33
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
34
+ "Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
35
+ "Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
36
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
37
+ "Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
38
+ "Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
39
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
40
+ "Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
41
+ "Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
42
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
43
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
44
+ "Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
45
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
46
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
47
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
48
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
49
+ "Note: you may need to restart the kernel to use updated packages.\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "%pip install accelerate -U"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "metadata": {
61
+ "nteract": {
62
+ "transient": {
63
+ "deleting": false
64
+ }
65
+ }
66
+ },
67
+ "outputs": [
68
+ {
69
+ "name": "stdout",
70
+ "output_type": "stream",
71
+ "text": [
72
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
73
+ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
74
+ "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
75
+ "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
76
+ "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
77
+ "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
78
+ "Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
79
+ "Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
80
+ "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
81
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
82
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
83
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
84
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
85
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
86
+ "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
87
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
88
+ "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
89
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
90
+ "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
91
+ "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
92
+ "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
93
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
94
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
95
+ "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
96
+ "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
97
+ "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
98
+ "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
99
+ "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
100
+ "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
101
+ "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
102
+ "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
103
+ "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
104
+ "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
105
+ "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
106
+ "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
107
+ "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
108
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
109
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
110
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
111
+ "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
112
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
113
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
114
+ "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
115
+ "Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
116
+ "Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
117
+ "Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
118
+ "Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
119
+ "Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
120
+ "Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
121
+ "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
122
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
123
+ "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
124
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
125
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
126
+ "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
127
+ "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
128
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
129
+ "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
130
+ "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
131
+ "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
132
+ "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
133
+ "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
134
+ "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
135
+ "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
136
+ "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
137
+ "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
138
+ "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
139
+ "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
140
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
141
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
142
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
143
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
144
+ "Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
145
+ "Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
146
+ "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
147
+ "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
148
+ "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
149
+ "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
150
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
151
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
152
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
153
+ "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
154
+ "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
155
+ "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
156
+ "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
157
+ "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
158
+ "Note: you may need to restart the kernel to use updated packages.\n"
159
+ ]
160
+ }
161
+ ],
162
+ "source": [
163
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 28,
169
+ "metadata": {
170
+ "datalore": {
171
+ "hide_input_from_viewers": false,
172
+ "hide_output_from_viewers": false,
173
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
174
+ "report_properties": {
175
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
176
+ },
177
+ "type": "CODE"
178
+ },
179
+ "gather": {
180
+ "logged": 1706503443742
181
+ },
182
+ "tags": []
183
+ },
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/html": [
188
+ " View run <strong style=\"color:#cdcd00\">daedra_0.05-distilbert-base-uncased</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
189
+ ],
190
+ "text/plain": [
191
+ "<IPython.core.display.HTML object>"
192
+ ]
193
+ },
194
+ "metadata": {},
195
+ "output_type": "display_data"
196
+ },
197
+ {
198
+ "data": {
199
+ "text/html": [
200
+ "Find logs at: <code>./wandb/run-20240129_152136-cwkdl3x7/logs</code>"
201
+ ],
202
+ "text/plain": [
203
+ "<IPython.core.display.HTML object>"
204
+ ]
205
+ },
206
+ "metadata": {},
207
+ "output_type": "display_data"
208
+ },
209
+ {
210
+ "name": "stdout",
211
+ "output_type": "stream",
212
+ "text": [
213
+ "The watermark extension is already loaded. To reload it, use:\n",
214
+ " %reload_ext watermark\n"
215
+ ]
216
+ }
217
+ ],
218
+ "source": [
219
+ "import pandas as pd\n",
220
+ "import numpy as np\n",
221
+ "import torch\n",
222
+ "import os\n",
223
+ "from typing import List, Union\n",
224
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
225
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
226
+ "import shap\n",
227
+ "import wandb\n",
228
+ "import evaluate\n",
229
+ "import logging\n",
230
+ "\n",
231
+ "wandb.finish()\n",
232
+ "\n",
233
+ "\n",
234
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
235
+ "\n",
236
+ "%load_ext watermark"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 4,
242
+ "metadata": {
243
+ "collapsed": false,
244
+ "gather": {
245
+ "logged": 1706503443899
246
+ },
247
+ "jupyter": {
248
+ "outputs_hidden": false
249
+ }
250
+ },
251
+ "outputs": [],
252
+ "source": [
253
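+ "# Training constants and W&B configuration; device falls back to CPU when CUDA is unavailable\n",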
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
254
+ "\n",
255
+ "SEED: int = 42\n",
256
+ "\n",
257
+ "BATCH_SIZE: int = 32\n",
258
+ "EPOCHS: int = 5\n",
259
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
260
+ "\n",
261
+ "# WandB configuration\n",
262
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
263
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
264
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": 5,
270
+ "metadata": {
271
+ "collapsed": false,
272
+ "jupyter": {
273
+ "outputs_hidden": false
274
+ }
275
+ },
276
+ "outputs": [
277
+ {
278
+ "name": "stdout",
279
+ "output_type": "stream",
280
+ "text": [
281
+ "re : 2.2.1\n",
282
+ "torch : 1.12.0\n",
283
+ "wandb : 0.16.2\n",
284
+ "logging : 0.5.1.2\n",
285
+ "numpy : 1.23.5\n",
286
+ "pandas : 2.0.2\n",
287
+ "evaluate: 0.4.1\n",
288
+ "shap : 0.44.1\n",
289
+ "\n"
290
+ ]
291
+ }
292
+ ],
293
+ "source": [
294
+ "%watermark --iversion"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": 6,
300
+ "metadata": {
301
+ "datalore": {
302
+ "hide_input_from_viewers": true,
303
+ "hide_output_from_viewers": true,
304
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
305
+ "type": "CODE"
306
+ }
307
+ },
308
+ "outputs": [
309
+ {
310
+ "name": "stdout",
311
+ "output_type": "stream",
312
+ "text": [
313
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
314
+ "Mon Jan 29 15:20:22 2024 \n",
315
+ "+---------------------------------------------------------------------------------------+\n",
316
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
317
+ "|-----------------------------------------+----------------------+----------------------+\n",
318
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
319
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
320
+ "| | | MIG M. |\n",
321
+ "|=========================================+======================+======================|\n",
322
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
323
+ "| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
324
+ "| | | N/A |\n",
325
+ "+-----------------------------------------+----------------------+----------------------+\n",
326
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
327
+ "| N/A 26C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
328
+ "| | | N/A |\n",
329
+ "+-----------------------------------------+----------------------+----------------------+\n",
330
+ "| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
331
+ "| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
332
+ "| | | N/A |\n",
333
+ "+-----------------------------------------+----------------------+----------------------+\n",
334
+ "| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
335
+ "| N/A 28C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
336
+ "| | | N/A |\n",
337
+ "+-----------------------------------------+----------------------+----------------------+\n",
338
+ " \n",
339
+ "+---------------------------------------------------------------------------------------+\n",
340
+ "| Processes: |\n",
341
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
342
+ "| ID ID Usage |\n",
343
+ "|=======================================================================================|\n",
344
+ "| No running processes found |\n",
345
+ "+---------------------------------------------------------------------------------------+\n"
346
+ ]
347
+ }
348
+ ],
349
+ "source": [
350
+ "!nvidia-smi"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "markdown",
355
+ "metadata": {
356
+ "datalore": {
357
+ "hide_input_from_viewers": false,
358
+ "hide_output_from_viewers": false,
359
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
360
+ "report_properties": {
361
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
362
+ },
363
+ "type": "MD"
364
+ }
365
+ },
366
+ "source": [
367
+ "## Loading the data set"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 7,
373
+ "metadata": {
374
+ "collapsed": false,
375
+ "gather": {
376
+ "logged": 1706503446033
377
+ },
378
+ "jupyter": {
379
+ "outputs_hidden": false
380
+ }
381
+ },
382
+ "outputs": [],
383
+ "source": [
384
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 8,
390
+ "metadata": {
391
+ "collapsed": false,
392
+ "gather": {
393
+ "logged": 1706503446252
394
+ },
395
+ "jupyter": {
396
+ "outputs_hidden": false,
397
+ "source_hidden": false
398
+ },
399
+ "nteract": {
400
+ "transient": {
401
+ "deleting": false
402
+ }
403
+ }
404
+ },
405
+ "outputs": [
406
+ {
407
+ "data": {
408
+ "text/plain": [
409
+ "DatasetDict({\n",
410
+ " train: Dataset({\n",
411
+ " features: ['id', 'text', 'label'],\n",
412
+ " num_rows: 1270444\n",
413
+ " })\n",
414
+ " test: Dataset({\n",
415
+ " features: ['id', 'text', 'label'],\n",
416
+ " num_rows: 272238\n",
417
+ " })\n",
418
+ " val: Dataset({\n",
419
+ " features: ['id', 'text', 'label'],\n",
420
+ " num_rows: 272238\n",
421
+ " })\n",
422
+ "})"
423
+ ]
424
+ },
425
+ "execution_count": 8,
426
+ "metadata": {},
427
+ "output_type": "execute_result"
428
+ }
429
+ ],
430
+ "source": [
431
+ "dataset"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "execution_count": 8,
437
+ "metadata": {
438
+ "gather": {
439
+ "logged": 1706503446498
440
+ }
441
+ },
442
+ "outputs": [],
443
+ "source": [
444
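+ "# Subsample each split for faster experimentation: seeded shuffle, then keep the first SUBSAMPLING fraction of rows\n",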
+ "SUBSAMPLING = 0.01\n",
445
+ "\n",
446
+ "if SUBSAMPLING < 1:\n",
447
+ " _ = DatasetDict()\n",
448
+ " for each in dataset.keys():\n",
449
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
450
+ "\n",
451
+ " dataset = _"
452
+ ]
453
+ },
454
+ {
455
+ "cell_type": "markdown",
456
+ "metadata": {},
457
+ "source": [
458
+ "## Tokenisation and encoding"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 10,
464
+ "metadata": {
465
+ "gather": {
466
+ "logged": 1706503446633
467
+ }
468
+ },
469
+ "outputs": [],
470
+ "source": [
471
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
472
+ " return ds_enc"
473
+ ]
474
+ },
475
+ {
476
+ "cell_type": "markdown",
477
+ "metadata": {},
478
+ "source": [
479
+ "## Evaluation metrics"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": 9,
485
+ "metadata": {
486
+ "gather": {
487
+ "logged": 1706503446863
488
+ }
489
+ },
490
+ "outputs": [],
491
+ "source": [
492
+ "accuracy = evaluate.load(\"accuracy\")\n",
493
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
494
+ "f1 = evaluate.load(\"f1\")"
495
+ ]
496
+ },
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": 10,
500
+ "metadata": {
501
+ "gather": {
502
+ "logged": 1706503447004
503
+ }
504
+ },
505
+ "outputs": [],
506
+ "source": [
507
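+ "# Turn logits into class predictions, then report accuracy, macro- and micro-averaged precision and recall, and micro-averaged F1\n",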
+ "def compute_metrics(eval_pred):\n",
508
+ " predictions, labels = eval_pred\n",
509
+ " predictions = np.argmax(predictions, axis=1)\n",
510
+ " return {\n",
511
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
512
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
513
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
514
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
515
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
516
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
517
+ " }"
518
+ ]
519
+ },
520
+ {
521
+ "cell_type": "markdown",
522
+ "metadata": {},
523
+ "source": [
524
+ "## Training"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "markdown",
529
+ "metadata": {},
530
+ "source": [
531
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
532
+ ]
533
+ },
534
+ {
535
+ "cell_type": "code",
536
+ "execution_count": 11,
537
+ "metadata": {
538
+ "gather": {
539
+ "logged": 1706503447186
540
+ }
541
+ },
542
+ "outputs": [],
543
+ "source": [
544
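+ "# Build the id2label mapping from the dataset's ClassLabel names (inverted later for label2id)\n",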
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
545
+ ]
546
+ },
547
+ {
548
+ "cell_type": "code",
549
+ "execution_count": 12,
550
+ "metadata": {
551
+ "jupyter": {
552
+ "outputs_hidden": false,
553
+ "source_hidden": false
554
+ },
555
+ "nteract": {
556
+ "transient": {
557
+ "deleting": false
558
+ }
559
+ }
560
+ },
561
+ "outputs": [],
562
+ "source": [
563
+ "def train_from_model(model_ckpt: str, push: bool = False):\n",
564
+ " print(f\"Initialising training based on {model_ckpt}...\")\n",
565
+ "\n",
566
+ " print(\"Tokenising...\")\n",
567
+ " tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
568
+ "\n",
569
+ " cols = dataset[\"train\"].column_names\n",
570
+ " cols.remove(\"label\")\n",
571
+ " ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
572
+ "\n",
573
+ " print(\"Loading model...\")\n",
574
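+ "    # Fall back to TensorFlow weights (from_tf=True) if the checkpoint ships no PyTorch weights\n",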
+ " try:\n",
575
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
576
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
577
+ " id2label=label_map, \n",
578
+ " label2id={v:k for k,v in label_map.items()})\n",
579
+ " except OSError:\n",
580
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
581
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
582
+ " id2label=label_map, \n",
583
+ " label2id={v:k for k,v in label_map.items()},\n",
584
+ " from_tf=True)\n",
585
+ "\n",
586
+ "\n",
587
+ " args = TrainingArguments(\n",
588
+ " output_dir=\"vaers\",\n",
589
+ " evaluation_strategy=\"epoch\",\n",
590
+ " save_strategy=\"epoch\",\n",
591
+ " learning_rate=2e-5,\n",
592
+ " per_device_train_batch_size=BATCH_SIZE,\n",
593
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
594
+ " num_train_epochs=EPOCHS,\n",
595
+ " weight_decay=.01,\n",
596
+ " logging_steps=1,\n",
597
+ " load_best_model_at_end=True,\n",
598
+ " run_name=f\"daedra-training\",\n",
599
+ " report_to=[\"wandb\"])\n",
600
+ "\n",
601
+ " trainer = Trainer(\n",
602
+ " model=model,\n",
603
+ " args=args,\n",
604
+ " train_dataset=ds_enc[\"train\"],\n",
605
+ " eval_dataset=ds_enc[\"test\"],\n",
606
+ " tokenizer=tokenizer,\n",
607
+ " compute_metrics=compute_metrics)\n",
608
+ " \n",
609
+ " if SUBSAMPLING != 1.0:\n",
610
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
611
+ " else:\n",
612
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
613
+ "\n",
614
+ " wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
615
+ " wandb_tag.append(f\"base:{model_ckpt}\")\n",
616
+ " \n",
617
+ " wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
618
+ "\n",
619
+ " print(\"Starting training...\")\n",
620
+ "\n",
621
+ " trainer.train()\n",
622
+ "\n",
623
+ " print(\"Training finished.\")\n",
624
+ "\n",
625
+ " if push:\n",
626
+ " variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
627
+ " tokenizer._tokenizer.save(\"tokenizer.json\")\n",
628
+ " tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
629
+ " sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
630
+ "\n",
631
+ " model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
632
+ " variant=variant,\n",
633
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
634
+ ]
635
+ },
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": 13,
639
+ "metadata": {
640
+ "gather": {
641
+ "logged": 1706503552083
642
+ }
643
+ },
644
+ "outputs": [],
645
+ "source": [
646
+ "\n",
647
+ "base_models = [\n",
648
+ " \"bert-base-uncased\",\n",
649
+ " \"distilbert-base-uncased\",\n",
650
+ "]"
651
+ ]
652
+ }
653
+ ],
654
+ "metadata": {
655
+ "datalore": {
656
+ "base_environment": "default",
657
+ "computation_mode": "JUPYTER",
658
+ "package_manager": "pip",
659
+ "packages": [
660
+ {
661
+ "name": "datasets",
662
+ "source": "PIP",
663
+ "version": "2.16.1"
664
+ },
665
+ {
666
+ "name": "torch",
667
+ "source": "PIP",
668
+ "version": "2.1.2"
669
+ },
670
+ {
671
+ "name": "accelerate",
672
+ "source": "PIP",
673
+ "version": "0.26.1"
674
+ }
675
+ ],
676
+ "report_row_ids": [
677
+ "un8W7ez7ZwoGb5Co6nydEV",
678
+ "40nN9Hvgi1clHNV5RAemI5",
679
+ "TgRD90H5NSPpKS41OeXI1w",
680
+ "ZOm5BfUs3h1EGLaUkBGeEB",
681
+ "kOP0CZWNSk6vqE3wkPp7Vc",
682
+ "W4PWcOu2O2pRaZyoE2W80h",
683
+ "RolbOnQLIftk0vy9mIcz5M",
684
+ "8OPhUgbaNJmOdiq5D3a6vK",
685
+ "5Qrt3jSvSrpK6Ne1hS6shL",
686
+ "hTq7nFUrovN5Ao4u6dIYWZ",
687
+ "I8WNZLpJ1DVP2wiCW7YBIB",
688
+ "SawhU3I9BewSE1XBPstpNJ",
689
+ "80EtLEl2FIE4FqbWnUD3nT"
690
+ ],
691
+ "version": 3
692
+ },
693
+ "kernel_info": {
694
+ "name": "python38-azureml-pt-tf"
695
+ },
696
+ "kernelspec": {
697
+ "display_name": "azureml_py38_PT_TF",
698
+ "language": "python",
699
+ "name": "python3"
700
+ },
701
+ "language_info": {
702
+ "codemirror_mode": {
703
+ "name": "ipython",
704
+ "version": 3
705
+ },
706
+ "file_extension": ".py",
707
+ "mimetype": "text/x-python",
708
+ "name": "python",
709
+ "nbconvert_exporter": "python",
710
+ "pygments_lexer": "ipython3",
711
+ "version": "3.8.5"
712
+ },
713
+ "microsoft": {
714
+ "host": {
715
+ "AzureML": {
716
+ "notebookHasBeenCompleted": true
717
+ }
718
+ },
719
+ "ms_spell_check": {
720
+ "ms_spell_check_language": "en"
721
+ }
722
+ },
723
+ "nteract": {
724
+ "version": "nteract-front-end@1.0.0"
725
+ }
726
+ },
727
+ "nbformat": 4,
728
+ "nbformat_minor": 4
729
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-17-44-52Z.ipynb ADDED
@@ -0,0 +1,739 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "gather": {
17
+ "logged": 1706475754655
18
+ },
19
+ "nteract": {
20
+ "transient": {
21
+ "deleting": false
22
+ }
23
+ },
24
+ "tags": []
25
+ },
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
32
+ "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
33
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
34
+ "Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
35
+ "Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
36
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
37
+ "Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
38
+ "Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
39
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
40
+ "Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
41
+ "Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
42
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
43
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
44
+ "Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
45
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
46
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
47
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
48
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
49
+ "Note: you may need to restart the kernel to use updated packages.\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "%pip install accelerate -U"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "metadata": {
61
+ "nteract": {
62
+ "transient": {
63
+ "deleting": false
64
+ }
65
+ }
66
+ },
67
+ "outputs": [
68
+ {
69
+ "name": "stdout",
70
+ "output_type": "stream",
71
+ "text": [
72
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
73
+ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
74
+ "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
75
+ "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
76
+ "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
77
+ "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
78
+ "Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
79
+ "Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
80
+ "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
81
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
82
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
83
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
84
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
85
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
86
+ "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
87
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
88
+ "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
89
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
90
+ "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
91
+ "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
92
+ "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
93
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
94
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
95
+ "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
96
+ "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
97
+ "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
98
+ "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
99
+ "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
100
+ "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
101
+ "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
102
+ "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
103
+ "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
104
+ "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
105
+ "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
106
+ "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
107
+ "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
108
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
109
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
110
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
111
+ "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
112
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
113
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
114
+ "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
115
+ "Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
116
+ "Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
117
+ "Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
118
+ "Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
119
+ "Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
120
+ "Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
121
+ "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
122
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
123
+ "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
124
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
125
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
126
+ "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
127
+ "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
128
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
129
+ "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
130
+ "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
131
+ "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
132
+ "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
133
+ "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
134
+ "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
135
+ "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
136
+ "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
137
+ "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
138
+ "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
139
+ "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
140
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
141
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
142
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
143
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
144
+ "Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
145
+ "Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
146
+ "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
147
+ "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
148
+ "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
149
+ "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
150
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
151
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
152
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
153
+ "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
154
+ "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
155
+ "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
156
+ "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
157
+ "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
158
+ "Note: you may need to restart the kernel to use updated packages.\n"
159
+ ]
160
+ }
161
+ ],
162
+ "source": [
163
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 28,
169
+ "metadata": {
170
+ "datalore": {
171
+ "hide_input_from_viewers": false,
172
+ "hide_output_from_viewers": false,
173
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
174
+ "report_properties": {
175
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
176
+ },
177
+ "type": "CODE"
178
+ },
179
+ "gather": {
180
+ "logged": 1706503443742
181
+ },
182
+ "tags": []
183
+ },
184
+ "outputs": [
185
+ {
186
+ "data": {
187
+ "text/html": [
188
+ " View run <strong style=\"color:#cdcd00\">daedra_0.05-distilbert-base-uncased</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
189
+ ],
190
+ "text/plain": [
191
+ "<IPython.core.display.HTML object>"
192
+ ]
193
+ },
194
+ "metadata": {},
195
+ "output_type": "display_data"
196
+ },
197
+ {
198
+ "data": {
199
+ "text/html": [
200
+ "Find logs at: <code>./wandb/run-20240129_152136-cwkdl3x7/logs</code>"
201
+ ],
202
+ "text/plain": [
203
+ "<IPython.core.display.HTML object>"
204
+ ]
205
+ },
206
+ "metadata": {},
207
+ "output_type": "display_data"
208
+ },
209
+ {
210
+ "name": "stdout",
211
+ "output_type": "stream",
212
+ "text": [
213
+ "The watermark extension is already loaded. To reload it, use:\n",
214
+ " %reload_ext watermark\n"
215
+ ]
216
+ }
217
+ ],
218
+ "source": [
219
+ "import pandas as pd\n",
220
+ "import numpy as np\n",
221
+ "import torch\n",
222
+ "import os\n",
223
+ "from typing import List, Union\n",
224
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
225
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
226
+ "import shap\n",
227
+ "import wandb\n",
228
+ "import evaluate\n",
229
+ "import logging\n",
230
+ "\n",
231
+ "wandb.finish()\n",
232
+ "\n",
233
+ "\n",
234
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
235
+ "\n",
236
+ "%load_ext watermark"
237
+ ]
238
+ },
239
+ {
240
+ "cell_type": "code",
241
+ "execution_count": 4,
242
+ "metadata": {
243
+ "collapsed": false,
244
+ "gather": {
245
+ "logged": 1706503443899
246
+ },
247
+ "jupyter": {
248
+ "outputs_hidden": false
249
+ }
250
+ },
251
+ "outputs": [],
252
+ "source": [
253
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
254
+ "\n",
255
+ "SEED: int = 42\n",
256
+ "\n",
257
+ "BATCH_SIZE: int = 32\n",
258
+ "EPOCHS: int = 5\n",
259
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
260
+ "\n",
261
+ "# WandB configuration\n",
262
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
263
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
264
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": 5,
270
+ "metadata": {
271
+ "collapsed": false,
272
+ "jupyter": {
273
+ "outputs_hidden": false
274
+ }
275
+ },
276
+ "outputs": [
277
+ {
278
+ "name": "stdout",
279
+ "output_type": "stream",
280
+ "text": [
281
+ "re : 2.2.1\n",
282
+ "torch : 1.12.0\n",
283
+ "wandb : 0.16.2\n",
284
+ "logging : 0.5.1.2\n",
285
+ "numpy : 1.23.5\n",
286
+ "pandas : 2.0.2\n",
287
+ "evaluate: 0.4.1\n",
288
+ "shap : 0.44.1\n",
289
+ "\n"
290
+ ]
291
+ }
292
+ ],
293
+ "source": [
294
+ "%watermark --iversion"
295
+ ]
296
+ },
297
+ {
298
+ "cell_type": "code",
299
+ "execution_count": 6,
300
+ "metadata": {
301
+ "datalore": {
302
+ "hide_input_from_viewers": true,
303
+ "hide_output_from_viewers": true,
304
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
305
+ "type": "CODE"
306
+ }
307
+ },
308
+ "outputs": [
309
+ {
310
+ "name": "stdout",
311
+ "output_type": "stream",
312
+ "text": [
313
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
314
+ "Mon Jan 29 15:20:22 2024 \n",
315
+ "+---------------------------------------------------------------------------------------+\n",
316
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
317
+ "|-----------------------------------------+----------------------+----------------------+\n",
318
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
319
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
320
+ "| | | MIG M. |\n",
321
+ "|=========================================+======================+======================|\n",
322
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
323
+ "| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
324
+ "| | | N/A |\n",
325
+ "+-----------------------------------------+----------------------+----------------------+\n",
326
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
327
+ "| N/A 26C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
328
+ "| | | N/A |\n",
329
+ "+-----------------------------------------+----------------------+----------------------+\n",
330
+ "| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
331
+ "| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
332
+ "| | | N/A |\n",
333
+ "+-----------------------------------------+----------------------+----------------------+\n",
334
+ "| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
335
+ "| N/A 28C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
336
+ "| | | N/A |\n",
337
+ "+-----------------------------------------+----------------------+----------------------+\n",
338
+ " \n",
339
+ "+---------------------------------------------------------------------------------------+\n",
340
+ "| Processes: |\n",
341
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
342
+ "| ID ID Usage |\n",
343
+ "|=======================================================================================|\n",
344
+ "| No running processes found |\n",
345
+ "+---------------------------------------------------------------------------------------+\n"
346
+ ]
347
+ }
348
+ ],
349
+ "source": [
350
+ "!nvidia-smi"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "markdown",
355
+ "metadata": {
356
+ "datalore": {
357
+ "hide_input_from_viewers": false,
358
+ "hide_output_from_viewers": false,
359
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
360
+ "report_properties": {
361
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
362
+ },
363
+ "type": "MD"
364
+ }
365
+ },
366
+ "source": [
367
+ "## Loading the data set"
368
+ ]
369
+ },
370
+ {
371
+ "cell_type": "code",
372
+ "execution_count": 7,
373
+ "metadata": {
374
+ "collapsed": false,
375
+ "gather": {
376
+ "logged": 1706503446033
377
+ },
378
+ "jupyter": {
379
+ "outputs_hidden": false
380
+ }
381
+ },
382
+ "outputs": [],
383
+ "source": [
384
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
385
+ ]
386
+ },
387
+ {
388
+ "cell_type": "code",
389
+ "execution_count": 8,
390
+ "metadata": {
391
+ "collapsed": false,
392
+ "gather": {
393
+ "logged": 1706503446252
394
+ },
395
+ "jupyter": {
396
+ "outputs_hidden": false,
397
+ "source_hidden": false
398
+ },
399
+ "nteract": {
400
+ "transient": {
401
+ "deleting": false
402
+ }
403
+ }
404
+ },
405
+ "outputs": [
406
+ {
407
+ "data": {
408
+ "text/plain": [
409
+ "DatasetDict({\n",
410
+ " train: Dataset({\n",
411
+ " features: ['id', 'text', 'label'],\n",
412
+ " num_rows: 1270444\n",
413
+ " })\n",
414
+ " test: Dataset({\n",
415
+ " features: ['id', 'text', 'label'],\n",
416
+ " num_rows: 272238\n",
417
+ " })\n",
418
+ " val: Dataset({\n",
419
+ " features: ['id', 'text', 'label'],\n",
420
+ " num_rows: 272238\n",
421
+ " })\n",
422
+ "})"
423
+ ]
424
+ },
425
+ "execution_count": 8,
426
+ "metadata": {},
427
+ "output_type": "execute_result"
428
+ }
429
+ ],
430
+ "source": [
431
+ "dataset"
432
+ ]
433
+ },
434
+ {
435
+ "cell_type": "code",
436
+ "execution_count": 8,
437
+ "metadata": {
438
+ "gather": {
439
+ "logged": 1706503446498
440
+ }
441
+ },
442
+ "outputs": [],
443
+ "source": [
444
+ "SUBSAMPLING = 0.01\n",
445
+ "\n",
446
+ "if SUBSAMPLING < 1:\n",
447
+ " _ = DatasetDict()\n",
448
+ " for each in dataset.keys():\n",
449
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
450
+ "\n",
451
+ " dataset = _"
452
+ ]
453
+ },
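Editor's note on the cell above: at SUBSAMPLING = 0.01 the selection keeps int(1_270_444 * 0.01) = 12,704 training rows and int(272_238 * 0.01) = 2,722 rows each for test and val, and because every split is shuffled with the fixed SEED = 42, the subsample is reproducible across runs. A minimal sketch of the same pattern on a single split:

    subsample = dataset["train"].shuffle(seed=SEED).select(range(int(len(dataset["train"]) * SUBSAMPLING)))  # deterministic 1% slice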
454
+ {
455
+ "cell_type": "markdown",
456
+ "metadata": {},
457
+ "source": [
458
+ "## Tokenisation and encoding"
459
+ ]
460
+ },
461
+ {
462
+ "cell_type": "code",
463
+ "execution_count": 10,
464
+ "metadata": {
465
+ "gather": {
466
+ "logged": 1706503446633
467
+ }
468
+ },
469
+ "outputs": [],
470
+ "source": [
471
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
472
+ " return ds_enc"
473
+ ]
474
+ },
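A minimal usage sketch for the helper above (assuming the body filled in in the cell, which tokenises every split of a DatasetDict):

    ds_enc = encode_ds(dataset)  # DatasetDict with input_ids and attention_mask alongside label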
475
+ {
476
+ "cell_type": "markdown",
477
+ "metadata": {},
478
+ "source": [
479
+ "## Evaluation metrics"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": 9,
485
+ "metadata": {
486
+ "gather": {
487
+ "logged": 1706503446863
488
+ }
489
+ },
490
+ "outputs": [],
491
+ "source": [
492
+ "accuracy = evaluate.load(\"accuracy\")\n",
493
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
494
+ "f1 = evaluate.load(\"f1\")"
495
+ ]
496
+ },
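Each `evaluate` metric loaded above exposes the same compute() interface, which returns a dict keyed by the metric name; for example:

    accuracy.compute(predictions=[0, 1, 1], references=[0, 1, 0])  # {'accuracy': 0.6666666666666666}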
497
+ {
498
+ "cell_type": "code",
499
+ "execution_count": 10,
500
+ "metadata": {
501
+ "gather": {
502
+ "logged": 1706503447004
503
+ }
504
+ },
505
+ "outputs": [],
506
+ "source": [
507
+ "def compute_metrics(eval_pred):\n",
508
+ " predictions, labels = eval_pred\n",
509
+ " predictions = np.argmax(predictions, axis=1)\n",
510
+ " return {\n",
511
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
512
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
513
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
514
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
515
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
516
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
517
+ " }"
518
+ ]
519
+ },
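A worked example of compute_metrics on hypothetical three-class logits (the label count here is illustrative, not the actual VAERS label set):

    import numpy as np
    logits = np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 3.0]])
    labels = np.array([0, 1, 2])
    compute_metrics((logits, labels))  # every metric is 1.0, since argmax recovers each label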
520
+ {
521
+ "cell_type": "markdown",
522
+ "metadata": {},
523
+ "source": [
524
+ "## Training"
525
+ ]
526
+ },
527
+ {
528
+ "cell_type": "markdown",
529
+ "metadata": {},
530
+ "source": [
531
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
532
+ ]
533
+ },
534
+ {
535
+ "cell_type": "code",
536
+ "execution_count": 11,
537
+ "metadata": {
538
+ "gather": {
539
+ "logged": 1706503447186
540
+ }
541
+ },
542
+ "outputs": [],
543
+ "source": [
544
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
545
+ ]
546
+ },
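To illustrate what this cell produces (the label names below are hypothetical; the real ones come from the dataset's label feature):

    label_map = {i: label for i, label in enumerate(["DIED", "ER_VISIT", "HOSPITALISED"])}
    # {0: 'DIED', 1: 'ER_VISIT', 2: 'HOSPITALISED'}; inverted to label2id in the training code below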
547
+ {
548
+ "cell_type": "code",
549
+ "execution_count": 12,
550
+ "metadata": {
551
+ "jupyter": {
552
+ "outputs_hidden": false,
553
+ "source_hidden": false
554
+ },
555
+ "nteract": {
556
+ "transient": {
557
+ "deleting": false
558
+ }
559
+ }
560
+ },
561
+ "outputs": [],
562
+ "source": [
563
+ "def train_from_model(model_ckpt: str, push: bool = False):\n",
564
+ " print(f\"Initialising training based on {model_ckpt}...\")\n",
565
+ "\n",
566
+ " print(\"Tokenising...\")\n",
567
+ " tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
568
+ "\n",
569
+ " cols = dataset[\"train\"].column_names\n",
570
+ " cols.remove(\"label\")\n",
571
+ " ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
572
+ "\n",
573
+ " print(\"Loading model...\")\n",
574
+ " try:\n",
575
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
576
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
577
+ " id2label=label_map, \n",
578
+ " label2id={v:k for k,v in label_map.items()})\n",
579
+ " except OSError:\n",
580
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
581
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
582
+ " id2label=label_map, \n",
583
+ " label2id={v:k for k,v in label_map.items()},\n",
584
+ " from_tf=True)\n",
585
+ "\n",
586
+ "\n",
587
+ " args = TrainingArguments(\n",
588
+ " output_dir=\"vaers\",\n",
589
+ " evaluation_strategy=\"epoch\",\n",
590
+ " save_strategy=\"epoch\",\n",
591
+ " learning_rate=2e-5,\n",
592
+ " per_device_train_batch_size=BATCH_SIZE,\n",
593
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
594
+ " num_train_epochs=EPOCHS,\n",
595
+ " weight_decay=.01,\n",
596
+ " logging_steps=1,\n",
597
+ " load_best_model_at_end=True,\n",
598
+ " run_name=f\"daedra-training\",\n",
599
+ " report_to=[\"wandb\"])\n",
600
+ "\n",
601
+ " trainer = Trainer(\n",
602
+ " model=model,\n",
603
+ " args=args,\n",
604
+ " train_dataset=ds_enc[\"train\"],\n",
605
+ " eval_dataset=ds_enc[\"test\"],\n",
606
+ " tokenizer=tokenizer,\n",
607
+ " compute_metrics=compute_metrics)\n",
608
+ " \n",
609
+ " if SUBSAMPLING != 1.0:\n",
610
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
611
+ " else:\n",
612
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
613
+ "\n",
614
+ " wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
615
+ " wandb_tag.append(f\"base:{model_ckpt}\")\n",
616
+ " \n",
617
+ " wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
618
+ "\n",
619
+ " print(\"Starting training...\")\n",
620
+ "\n",
621
+ " trainer.train()\n",
622
+ "\n",
623
+ " print(\"Training finished.\")\n",
624
+ "\n",
625
+ " if push:\n",
626
+ " variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
627
+ " tokenizer._tokenizer.save(\"tokenizer.json\")\n",
628
+ " tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
629
+ " sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
630
+ "\n",
631
+ " model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
632
+ " variant=variant,\n",
633
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
634
+ ]
635
+ },
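A note on throughput for the function above: `Trainer` replicates the model across every visible GPU, so with the four V100s reported by nvidia-smi earlier, per_device_train_batch_size = 32 gives an effective batch of 32 × 4 = 128 sequences per optimisation step.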
636
+ {
637
+ "cell_type": "code",
638
+ "execution_count": 13,
639
+ "metadata": {
640
+ "gather": {
641
+ "logged": 1706503552083
642
+ }
643
+ },
644
+ "outputs": [],
645
+ "source": [
646
+ "\n",
647
+ "base_models = [\n",
648
+ " \"bert-base-uncased\",\n",
649
+ " \"distilbert-base-uncased\",\n",
650
+ "]"
651
+ ]
652
+ },
653
+ {
654
+ "cell_type": "code",
655
+ "execution_count": null,
656
+ "metadata": {},
657
+ "outputs": [],
658
+ "source": [
659
+ "for md in base_models:\n",
660
+ " train_from_model(md)"
661
+ ]
662
+ }
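To publish the best checkpoint to the Hub after training, the same helper can be called with push=True; a sketch that assumes Hugging Face credentials are already configured on the machine:

    train_from_model("distilbert-base-uncased", push=True)  # pushes tokenizer and model to chrisvoncsefalvay/daedra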
663
+ ],
664
+ "metadata": {
665
+ "datalore": {
666
+ "base_environment": "default",
667
+ "computation_mode": "JUPYTER",
668
+ "package_manager": "pip",
669
+ "packages": [
670
+ {
671
+ "name": "datasets",
672
+ "source": "PIP",
673
+ "version": "2.16.1"
674
+ },
675
+ {
676
+ "name": "torch",
677
+ "source": "PIP",
678
+ "version": "2.1.2"
679
+ },
680
+ {
681
+ "name": "accelerate",
682
+ "source": "PIP",
683
+ "version": "0.26.1"
684
+ }
685
+ ],
686
+ "report_row_ids": [
687
+ "un8W7ez7ZwoGb5Co6nydEV",
688
+ "40nN9Hvgi1clHNV5RAemI5",
689
+ "TgRD90H5NSPpKS41OeXI1w",
690
+ "ZOm5BfUs3h1EGLaUkBGeEB",
691
+ "kOP0CZWNSk6vqE3wkPp7Vc",
692
+ "W4PWcOu2O2pRaZyoE2W80h",
693
+ "RolbOnQLIftk0vy9mIcz5M",
694
+ "8OPhUgbaNJmOdiq5D3a6vK",
695
+ "5Qrt3jSvSrpK6Ne1hS6shL",
696
+ "hTq7nFUrovN5Ao4u6dIYWZ",
697
+ "I8WNZLpJ1DVP2wiCW7YBIB",
698
+ "SawhU3I9BewSE1XBPstpNJ",
699
+ "80EtLEl2FIE4FqbWnUD3nT"
700
+ ],
701
+ "version": 3
702
+ },
703
+ "kernel_info": {
704
+ "name": "python38-azureml-pt-tf"
705
+ },
706
+ "kernelspec": {
707
+ "display_name": "azureml_py38_PT_TF",
708
+ "language": "python",
709
+ "name": "python3"
710
+ },
711
+ "language_info": {
712
+ "codemirror_mode": {
713
+ "name": "ipython",
714
+ "version": 3
715
+ },
716
+ "file_extension": ".py",
717
+ "mimetype": "text/x-python",
718
+ "name": "python",
719
+ "nbconvert_exporter": "python",
720
+ "pygments_lexer": "ipython3",
721
+ "version": "3.8.5"
722
+ },
723
+ "microsoft": {
724
+ "host": {
725
+ "AzureML": {
726
+ "notebookHasBeenCompleted": true
727
+ }
728
+ },
729
+ "ms_spell_check": {
730
+ "ms_spell_check_language": "en"
731
+ }
732
+ },
733
+ "nteract": {
734
+ "version": "nteract-front-end@1.0.0"
735
+ }
736
+ },
737
+ "nbformat": 4,
738
+ "nbformat_minor": 4
739
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-3-40-27Z.ipynb ADDED
@@ -0,0 +1,1001 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
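Once a trained checkpoint has been pushed (see the training code further down), predictions can be obtained with the transformers pipeline API; a sketch assuming the chrisvoncsefalvay/daedra target used later in this notebook:

    from transformers import pipeline
    classifier = pipeline("text-classification", model="chrisvoncsefalvay/daedra")
    classifier("Patient developed a mild fever and myalgia the day after vaccination.")  # predicted outcome label and score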
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "gather": {
17
+ "logged": 1706475754655
18
+ },
19
+ "nteract": {
20
+ "transient": {
21
+ "deleting": false
22
+ }
23
+ },
24
+ "tags": []
25
+ },
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
32
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
33
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
34
+ "Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
35
+ "Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
36
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
37
+ "Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
38
+ "Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
39
+ "Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
40
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
41
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
42
+ "Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
43
+ "Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
44
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
45
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
46
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
47
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
48
+ "Note: you may need to restart the kernel to use updated packages.\n"
49
+ ]
50
+ }
51
+ ],
52
+ "source": [
53
+ "%pip install accelerate -U"
54
+ ]
55
+ },
56
+ {
57
+ "cell_type": "code",
58
+ "execution_count": 2,
59
+ "metadata": {
60
+ "nteract": {
61
+ "transient": {
62
+ "deleting": false
63
+ }
64
+ }
65
+ },
66
+ "outputs": [
67
+ {
68
+ "name": "stdout",
69
+ "output_type": "stream",
70
+ "text": [
71
+ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
72
+ "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
73
+ "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
74
+ "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
75
+ "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
76
+ "Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
77
+ "Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
78
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
79
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
80
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
81
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
82
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
83
+ "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
84
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
85
+ "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
86
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
87
+ "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
88
+ "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
89
+ "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
90
+ "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
91
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
92
+ "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
93
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
94
+ "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
95
+ "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
96
+ "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
97
+ "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
98
+ "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
99
+ "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
100
+ "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
101
+ "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
102
+ "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
103
+ "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
104
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
105
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
106
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
107
+ "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
108
+ "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
109
+ "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
110
+ "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
111
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
112
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
113
+ "Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
114
+ "Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
115
+ "Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
116
+ "Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
117
+ "Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
118
+ "Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
119
+ "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
120
+ "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
121
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
122
+ "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
123
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
124
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
125
+ "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
126
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
127
+ "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
128
+ "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
129
+ "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
130
+ "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
131
+ "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
132
+ "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
133
+ "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
134
+ "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
135
+ "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
136
+ "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
137
+ "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
138
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
139
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
140
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
141
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
142
+ "Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
143
+ "Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
144
+ "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
145
+ "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
146
+ "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
147
+ "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
148
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
149
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
150
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
151
+ "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
152
+ "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
153
+ "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
154
+ "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
155
+ "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
156
+ "Note: you may need to restart the kernel to use updated packages.\n"
157
+ ]
158
+ }
159
+ ],
160
+ "source": [
161
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
162
+ ]
163
+ },
164
+ {
165
+ "cell_type": "code",
166
+ "execution_count": 3,
167
+ "metadata": {
168
+ "datalore": {
169
+ "hide_input_from_viewers": false,
170
+ "hide_output_from_viewers": false,
171
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
172
+ "report_properties": {
173
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
174
+ },
175
+ "type": "CODE"
176
+ },
177
+ "gather": {
178
+ "logged": 1706486372154
179
+ },
180
+ "tags": []
181
+ },
182
+ "outputs": [
183
+ {
184
+ "name": "stderr",
185
+ "output_type": "stream",
186
+ "text": [
187
+ "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
188
+ " from .autonotebook import tqdm as notebook_tqdm\n",
189
+ "2024-01-28 23:59:27.034680: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
190
+ "To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
191
+ "2024-01-28 23:59:27.996419: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
192
+ "2024-01-28 23:59:27.999143: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
193
+ "2024-01-28 23:59:27.999161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
194
+ "[codecarbon INFO @ 23:59:30] [setup] RAM Tracking...\n",
195
+ "[codecarbon INFO @ 23:59:30] [setup] GPU Tracking...\n",
196
+ "[codecarbon INFO @ 23:59:30] Tracking Nvidia GPU via pynvml\n",
197
+ "[codecarbon INFO @ 23:59:30] [setup] CPU Tracking...\n",
198
+ "[codecarbon WARNING @ 23:59:30] No CPU tracking mode found. Falling back on CPU constant mode.\n",
199
+ "[codecarbon WARNING @ 23:59:31] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n",
200
+ "[codecarbon INFO @ 23:59:31] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
201
+ "[codecarbon INFO @ 23:59:31] >>> Tracker's metadata:\n",
202
+ "[codecarbon INFO @ 23:59:31] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n",
203
+ "[codecarbon INFO @ 23:59:31] Python version: 3.8.5\n",
204
+ "[codecarbon INFO @ 23:59:31] CodeCarbon version: 2.3.3\n",
205
+ "[codecarbon INFO @ 23:59:31] Available RAM : 440.883 GB\n",
206
+ "[codecarbon INFO @ 23:59:31] CPU count: 24\n",
207
+ "[codecarbon INFO @ 23:59:31] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
208
+ "[codecarbon INFO @ 23:59:31] GPU count: 4\n",
209
+ "[codecarbon INFO @ 23:59:31] GPU model: 4 x Tesla V100-PCIE-16GB\n",
210
+ "[codecarbon WARNING @ 23:59:32] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
211
+ ]
212
+ }
213
+ ],
214
+ "source": [
215
+ "import pandas as pd\n",
216
+ "import numpy as np\n",
217
+ "import torch\n",
218
+ "import os\n",
219
+ "from typing import List, Union\n",
220
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
221
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
222
+ "import shap\n",
223
+ "import wandb\n",
224
+ "import evaluate\n",
225
+ "from codecarbon import EmissionsTracker\n",
226
+ "import logging\n",
227
+ "\n",
228
+ "logging.getLogger('codecarbon').propagate = False\n",
229
+ "\n",
230
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
231
+ "tracker = EmissionsTracker()\n",
232
+ "\n",
233
+ "%load_ext watermark"
234
+ ]
235
+ },
236
+ {
237
+ "cell_type": "code",
238
+ "execution_count": 4,
239
+ "metadata": {
240
+ "collapsed": false,
241
+ "gather": {
242
+ "logged": 1706486372304
243
+ },
244
+ "jupyter": {
245
+ "outputs_hidden": false
246
+ }
247
+ },
248
+ "outputs": [],
249
+ "source": [
250
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
251
+ "\n",
252
+ "SEED: int = 42\n",
253
+ "\n",
254
+ "BATCH_SIZE: int = 32\n",
255
+ "EPOCHS: int = 3\n",
256
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
257
+ "\n",
258
+ "# WandB configuration\n",
259
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
260
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
261
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
262
+ ]
263
+ },
264
+ {
265
+ "cell_type": "code",
266
+ "execution_count": 5,
267
+ "metadata": {
268
+ "collapsed": false,
269
+ "jupyter": {
270
+ "outputs_hidden": false
271
+ }
272
+ },
273
+ "outputs": [
274
+ {
275
+ "name": "stdout",
276
+ "output_type": "stream",
277
+ "text": [
278
+ "re : 2.2.1\n",
279
+ "evaluate: 0.4.1\n",
280
+ "pandas : 2.0.2\n",
281
+ "wandb : 0.16.2\n",
282
+ "numpy : 1.23.5\n",
283
+ "torch : 1.12.0\n",
284
+ "logging : 0.5.1.2\n",
285
+ "shap : 0.44.1\n",
286
+ "\n"
287
+ ]
288
+ }
289
+ ],
290
+ "source": [
291
+ "%watermark --iversion"
292
+ ]
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "execution_count": 6,
297
+ "metadata": {
298
+ "datalore": {
299
+ "hide_input_from_viewers": true,
300
+ "hide_output_from_viewers": true,
301
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
302
+ "type": "CODE"
303
+ }
304
+ },
305
+ "outputs": [
306
+ {
307
+ "name": "stdout",
308
+ "output_type": "stream",
309
+ "text": [
310
+ "Sun Jan 28 23:59:32 2024 \r\n",
311
+ "+---------------------------------------------------------------------------------------+\r\n",
312
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n",
313
+ "|-----------------------------------------+----------------------+----------------------+\r\n",
314
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n",
315
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n",
316
+ "| | | MIG M. |\r\n",
317
+ "|=========================================+======================+======================|\r\n",
318
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n",
319
+ "| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
320
+ "| | | N/A |\r\n",
321
+ "+-----------------------------------------+----------------------+----------------------+\r\n",
322
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n",
323
+ "| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
324
+ "| | | N/A |\r\n",
325
+ "+-----------------------------------------+----------------------+----------------------+\r\n",
326
+ "| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n",
327
+ "| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
328
+ "| | | N/A |\r\n",
329
+ "+-----------------------------------------+----------------------+----------------------+\r\n",
330
+ "| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n",
331
+ "| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
332
+ "| | | N/A |\r\n",
333
+ "+-----------------------------------------+----------------------+----------------------+\r\n",
334
+ " \r\n",
335
+ "+---------------------------------------------------------------------------------------+\r\n",
336
+ "| Processes: |\r\n",
337
+ "| GPU GI CI PID Type Process name GPU Memory |\r\n",
338
+ "| ID ID Usage |\r\n",
339
+ "|=======================================================================================|\r\n",
340
+ "| No running processes found |\r\n",
341
+ "+---------------------------------------------------------------------------------------+\r\n"
342
+ ]
343
+ }
344
+ ],
345
+ "source": [
346
+ "!nvidia-smi"
347
+ ]
348
+ },
349
+ {
350
+ "cell_type": "markdown",
351
+ "metadata": {
352
+ "datalore": {
353
+ "hide_input_from_viewers": false,
354
+ "hide_output_from_viewers": false,
355
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
356
+ "report_properties": {
357
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
358
+ },
359
+ "type": "MD"
360
+ }
361
+ },
362
+ "source": [
363
+ "## Loading the data set"
364
+ ]
365
+ },
366
+ {
367
+ "cell_type": "code",
368
+ "execution_count": 7,
369
+ "metadata": {
370
+ "collapsed": false,
371
+ "gather": {
372
+ "logged": 1706486373931
373
+ },
374
+ "jupyter": {
375
+ "outputs_hidden": false
376
+ }
377
+ },
378
+ "outputs": [],
379
+ "source": [
380
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
381
+ ]
382
+ },
383
+ {
384
+ "cell_type": "code",
385
+ "execution_count": 8,
386
+ "metadata": {
387
+ "collapsed": false,
388
+ "gather": {
389
+ "logged": 1706486374218
390
+ },
391
+ "jupyter": {
392
+ "outputs_hidden": false,
393
+ "source_hidden": false
394
+ },
395
+ "nteract": {
396
+ "transient": {
397
+ "deleting": false
398
+ }
399
+ }
400
+ },
401
+ "outputs": [
402
+ {
403
+ "data": {
404
+ "text/plain": [
405
+ "DatasetDict({\n",
406
+ " train: Dataset({\n",
407
+ " features: ['id', 'text', 'label'],\n",
408
+ " num_rows: 1270444\n",
409
+ " })\n",
410
+ " test: Dataset({\n",
411
+ " features: ['id', 'text', 'label'],\n",
412
+ " num_rows: 272238\n",
413
+ " })\n",
414
+ " val: Dataset({\n",
415
+ " features: ['id', 'text', 'label'],\n",
416
+ " num_rows: 272238\n",
417
+ " })\n",
418
+ "})"
419
+ ]
420
+ },
421
+ "execution_count": 8,
422
+ "metadata": {},
423
+ "output_type": "execute_result"
424
+ }
425
+ ],
426
+ "source": [
427
+ "dataset"
428
+ ]
429
+ },
430
+ {
431
+ "cell_type": "code",
432
+ "execution_count": 9,
433
+ "metadata": {
434
+ "gather": {
435
+ "logged": 1706486374480
436
+ }
437
+ },
438
+ "outputs": [],
439
+ "source": [
440
+ "SUBSAMPLING = 0.5\n",
441
+ "\n",
442
+ "if SUBSAMPLING < 1:\n",
443
+ " _ = DatasetDict()\n",
444
+ " for each in dataset.keys():\n",
445
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
446
+ "\n",
447
+ " dataset = _"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "markdown",
452
+ "metadata": {},
453
+ "source": [
454
+ "## Tokenisation and encoding"
455
+ ]
456
+ },
457
+ {
458
+ "cell_type": "code",
459
+ "execution_count": 10,
460
+ "metadata": {
461
+ "gather": {
462
+ "logged": 1706486375030
463
+ }
464
+ },
465
+ "outputs": [],
466
+ "source": [
467
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
468
+ " return ds_enc"
469
+ ]
470
+ },
471
+ {
472
+ "cell_type": "markdown",
473
+ "metadata": {},
474
+ "source": [
475
+ "## Evaluation metrics"
476
+ ]
477
+ },
478
+ {
479
+ "cell_type": "code",
480
+ "execution_count": 11,
481
+ "metadata": {
482
+ "gather": {
483
+ "logged": 1706486375197
484
+ }
485
+ },
486
+ "outputs": [],
487
+ "source": [
488
+ "accuracy = evaluate.load(\"accuracy\")\n",
489
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
490
+ "f1 = evaluate.load(\"f1\")"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": 12,
496
+ "metadata": {
497
+ "gather": {
498
+ "logged": 1706486375361
499
+ }
500
+ },
501
+ "outputs": [],
502
+ "source": [
503
+ "def compute_metrics(eval_pred):\n",
504
+ " predictions, labels = eval_pred\n",
505
+ " predictions = np.argmax(predictions, axis=1)\n",
506
+ " return {\n",
507
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
508
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
509
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
510
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
511
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
512
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
513
+ " }"
514
+ ]
515
+ },
516
+ {
517
+ "cell_type": "markdown",
518
+ "metadata": {},
519
+ "source": [
520
+ "## Training"
521
+ ]
522
+ },
523
+ {
524
+ "cell_type": "markdown",
525
+ "metadata": {},
526
+ "source": [
527
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "code",
532
+ "execution_count": 13,
533
+ "metadata": {
534
+ "gather": {
535
+ "logged": 1706486375569
536
+ }
537
+ },
538
+ "outputs": [],
539
+ "source": [
540
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
541
+ ]
542
+ },
543
+ {
544
+ "cell_type": "code",
545
+ "execution_count": 14,
546
+ "metadata": {
547
+ "gather": {
548
+ "logged": 1706486433708
549
+ }
550
+ },
551
+ "outputs": [
552
+ {
553
+ "name": "stderr",
554
+ "output_type": "stream",
555
+ "text": [
556
+ "Map: 100%|██████████| 136119/136119 [00:56<00:00, 2412.00 examples/s]\n",
557
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
558
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
559
+ ]
560
+ }
561
+ ],
562
+ "source": [
563
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
564
+ "\n",
565
+ "cols = dataset[\"train\"].column_names\n",
566
+ "cols.remove(\"label\")\n",
567
+ "ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n",
568
+ "\n",
569
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
570
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
571
+ " id2label=label_map, \n",
572
+ " label2id={v:k for k,v in label_map.items()})\n",
573
+ "\n",
574
+ "args = TrainingArguments(\n",
575
+ " output_dir=\"vaers\",\n",
576
+ " evaluation_strategy=\"epoch\",\n",
577
+ " save_strategy=\"epoch\",\n",
578
+ " learning_rate=2e-5,\n",
579
+ " per_device_train_batch_size=BATCH_SIZE,\n",
580
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
581
+ " num_train_epochs=EPOCHS,\n",
582
+ " weight_decay=.01,\n",
583
+ " logging_steps=1,\n",
584
+ " load_best_model_at_end=True,\n",
585
+ " run_name=f\"daedra-training\",\n",
586
+ " report_to=[\"wandb\"])\n",
587
+ "\n",
588
+ "trainer = Trainer(\n",
589
+ " model=model,\n",
590
+ " args=args,\n",
591
+ " train_dataset=ds_enc[\"train\"],\n",
592
+ " eval_dataset=ds_enc[\"test\"],\n",
593
+ " tokenizer=tokenizer,\n",
594
+ " compute_metrics=compute_metrics)"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": 15,
600
+ "metadata": {
601
+ "gather": {
602
+ "logged": 1706486444806
603
+ }
604
+ },
605
+ "outputs": [
606
+ {
607
+ "name": "stderr",
608
+ "output_type": "stream",
609
+ "text": [
610
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
611
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
612
+ ]
613
+ },
614
+ {
615
+ "data": {
616
+ "text/html": [
617
+ "Tracking run with wandb version 0.16.2"
618
+ ],
619
+ "text/plain": [
620
+ "<IPython.core.display.HTML object>"
621
+ ]
622
+ },
623
+ "metadata": {},
624
+ "output_type": "display_data"
625
+ },
626
+ {
627
+ "data": {
628
+ "text/html": [
629
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_000035-xm7aguww</code>"
630
+ ],
631
+ "text/plain": [
632
+ "<IPython.core.display.HTML object>"
633
+ ]
634
+ },
635
+ "metadata": {},
636
+ "output_type": "display_data"
637
+ },
638
+ {
639
+ "data": {
640
+ "text/html": [
641
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
642
+ ],
643
+ "text/plain": [
644
+ "<IPython.core.display.HTML object>"
645
+ ]
646
+ },
647
+ "metadata": {},
648
+ "output_type": "display_data"
649
+ },
650
+ {
651
+ "data": {
652
+ "text/html": [
653
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
654
+ ],
655
+ "text/plain": [
656
+ "<IPython.core.display.HTML object>"
657
+ ]
658
+ },
659
+ "metadata": {},
660
+ "output_type": "display_data"
661
+ },
662
+ {
663
+ "data": {
664
+ "text/html": [
665
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww</a>"
666
+ ],
667
+ "text/plain": [
668
+ "<IPython.core.display.HTML object>"
669
+ ]
670
+ },
671
+ "metadata": {},
672
+ "output_type": "display_data"
673
+ },
674
+ {
675
+ "data": {
676
+ "text/html": [
677
+ "Finishing last run (ID:xm7aguww) before initializing another..."
678
+ ],
679
+ "text/plain": [
680
+ "<IPython.core.display.HTML object>"
681
+ ]
682
+ },
683
+ "metadata": {},
684
+ "output_type": "display_data"
685
+ },
686
+ {
687
+ "data": {
688
+ "text/html": [
689
+ " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
690
+ ],
691
+ "text/plain": [
692
+ "<IPython.core.display.HTML object>"
693
+ ]
694
+ },
695
+ "metadata": {},
696
+ "output_type": "display_data"
697
+ },
698
+ {
699
+ "data": {
700
+ "text/html": [
701
+ "Find logs at: <code>./wandb/run-20240129_000035-xm7aguww/logs</code>"
702
+ ],
703
+ "text/plain": [
704
+ "<IPython.core.display.HTML object>"
705
+ ]
706
+ },
707
+ "metadata": {},
708
+ "output_type": "display_data"
709
+ },
710
+ {
711
+ "data": {
712
+ "text/html": [
713
+ "Successfully finished last run (ID:xm7aguww). Initializing new run:<br/>"
714
+ ],
715
+ "text/plain": [
716
+ "<IPython.core.display.HTML object>"
717
+ ]
718
+ },
719
+ "metadata": {},
720
+ "output_type": "display_data"
721
+ },
722
+ {
723
+ "data": {
724
+ "text/html": [
725
+ "Tracking run with wandb version 0.16.2"
726
+ ],
727
+ "text/plain": [
728
+ "<IPython.core.display.HTML object>"
729
+ ]
730
+ },
731
+ "metadata": {},
732
+ "output_type": "display_data"
733
+ },
734
+ {
735
+ "data": {
736
+ "text/html": [
737
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_000037-qfvjuxwm</code>"
738
+ ],
739
+ "text/plain": [
740
+ "<IPython.core.display.HTML object>"
741
+ ]
742
+ },
743
+ "metadata": {},
744
+ "output_type": "display_data"
745
+ },
746
+ {
747
+ "data": {
748
+ "text/html": [
749
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
750
+ ],
751
+ "text/plain": [
752
+ "<IPython.core.display.HTML object>"
753
+ ]
754
+ },
755
+ "metadata": {},
756
+ "output_type": "display_data"
757
+ },
758
+ {
759
+ "data": {
760
+ "text/html": [
761
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
762
+ ],
763
+ "text/plain": [
764
+ "<IPython.core.display.HTML object>"
765
+ ]
766
+ },
767
+ "metadata": {},
768
+ "output_type": "display_data"
769
+ },
770
+ {
771
+ "data": {
772
+ "text/html": [
773
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm</a>"
774
+ ],
775
+ "text/plain": [
776
+ "<IPython.core.display.HTML object>"
777
+ ]
778
+ },
779
+ "metadata": {},
780
+ "output_type": "display_data"
781
+ },
782
+ {
783
+ "data": {
784
+ "text/html": [
785
+ "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
786
+ ],
787
+ "text/plain": [
788
+ "<wandb.sdk.wandb_run.Run at 0x7f6e1c1e64c0>"
789
+ ]
790
+ },
791
+ "execution_count": 15,
792
+ "metadata": {},
793
+ "output_type": "execute_result"
794
+ }
795
+ ],
796
+ "source": [
797
+ "if SUBSAMPLING != 1.0:\n",
798
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
799
+ "else:\n",
800
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
801
+ "\n",
802
+ "wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
803
+ "wandb_tag.append(f\"base:{model_ckpt}\")\n",
804
+ " \n",
805
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
806
+ ]
807
+ },
808
+ {
809
+ "cell_type": "code",
810
+ "execution_count": 16,
811
+ "metadata": {
812
+ "gather": {
813
+ "logged": 1706486541798
814
+ }
815
+ },
816
+ "outputs": [
817
+ {
818
+ "name": "stderr",
819
+ "output_type": "stream",
820
+ "text": [
821
+ "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
822
+ ]
823
+ },
824
+ {
825
+ "data": {
826
+ "text/html": [
827
+ "\n",
828
+ " <div>\n",
829
+ " \n",
830
+ " <progress value='138' max='14889' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
831
+ " [ 138/14889 01:27 < 2:38:40, 1.55 it/s, Epoch 0.03/3]\n",
832
+ " </div>\n",
833
+ " <table border=\"1\" class=\"dataframe\">\n",
834
+ " <thead>\n",
835
+ " <tr style=\"text-align: left;\">\n",
836
+ " <th>Epoch</th>\n",
837
+ " <th>Training Loss</th>\n",
838
+ " <th>Validation Loss</th>\n",
839
+ " </tr>\n",
840
+ " </thead>\n",
841
+ " <tbody>\n",
842
+ " </tbody>\n",
843
+ "</table><p>"
844
+ ],
845
+ "text/plain": [
846
+ "<IPython.core.display.HTML object>"
847
+ ]
848
+ },
849
+ "metadata": {},
850
+ "output_type": "display_data"
851
+ },
852
+ {
853
+ "name": "stderr",
854
+ "output_type": "stream",
855
+ "text": [
856
+ "[codecarbon INFO @ 00:00:59] Energy consumed for RAM : 0.000690 kWh. RAM Power : 165.33123922348022 W\n",
857
+ "[codecarbon INFO @ 00:00:59] Energy consumed for all GPUs : 0.001468 kWh. Total GPU Power : 351.7461416352095 W\n",
858
+ "[codecarbon INFO @ 00:00:59] Energy consumed for all CPUs : 0.000178 kWh. Total CPU Power : 42.5 W\n",
859
+ "[codecarbon INFO @ 00:00:59] 0.002336 kWh of electricity used since the beginning.\n",
860
+ "[codecarbon INFO @ 00:01:14] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n",
861
+ "[codecarbon INFO @ 00:01:14] Energy consumed for all GPUs : 0.004025 kWh. Total GPU Power : 614.3289592286081 W\n",
862
+ "[codecarbon INFO @ 00:01:14] Energy consumed for all CPUs : 0.000355 kWh. Total CPU Power : 42.5 W\n",
863
+ "[codecarbon INFO @ 00:01:14] 0.005757 kWh of electricity used since the beginning.\n",
864
+ "[codecarbon INFO @ 00:01:29] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n",
865
+ "[codecarbon INFO @ 00:01:29] Energy consumed for all GPUs : 0.006586 kWh. Total GPU Power : 615.1209943099732 W\n",
866
+ "[codecarbon INFO @ 00:01:29] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n",
867
+ "[codecarbon INFO @ 00:01:29] 0.009184 kWh of electricity used since the beginning.\n",
868
+ "[codecarbon INFO @ 00:01:44] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n",
869
+ "[codecarbon INFO @ 00:01:44] Energy consumed for all GPUs : 0.009201 kWh. Total GPU Power : 628.2177623002755 W\n",
870
+ "[codecarbon INFO @ 00:01:44] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n",
871
+ "[codecarbon INFO @ 00:01:44] 0.012664 kWh of electricity used since the beginning.\n",
872
+ "[codecarbon INFO @ 00:01:59] Energy consumed for RAM : 0.003442 kWh. RAM Power : 165.33123922348022 W\n",
873
+ "[codecarbon INFO @ 00:01:59] Energy consumed for all GPUs : 0.011831 kWh. Total GPU Power : 631.8056507544826 W\n",
874
+ "[codecarbon INFO @ 00:01:59] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n",
875
+ "[codecarbon INFO @ 00:01:59] 0.016159 kWh of electricity used since the beginning.\n",
876
+ "[codecarbon INFO @ 00:02:14] Energy consumed for RAM : 0.004130 kWh. RAM Power : 165.33123922348022 W\n",
877
+ "[codecarbon INFO @ 00:02:14] Energy consumed for all GPUs : 0.014450 kWh. Total GPU Power : 629.2086149888297 W\n",
878
+ "[codecarbon INFO @ 00:02:14] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n",
879
+ "[codecarbon INFO @ 00:02:14] 0.019643 kWh of electricity used since the beginning.\n",
880
+ "\n",
881
+ "KeyboardInterrupt\n",
882
+ "\n"
883
+ ]
884
+ }
885
+ ],
886
+ "source": [
887
+ "tracker.start()\n",
888
+ "trainer.train()\n",
889
+ "tracker.stop()\n"
890
+ ]
891
+ },
892
+ {
893
+ "cell_type": "code",
894
+ "execution_count": null,
895
+ "metadata": {
896
+ "gather": {
897
+ "logged": 1706486541918
898
+ }
899
+ },
900
+ "outputs": [],
901
+ "source": [
902
+ "wandb.finish()"
903
+ ]
904
+ },
905
+ {
906
+ "cell_type": "code",
907
+ "execution_count": null,
908
+ "metadata": {
909
+ "gather": {
910
+ "logged": 1706486541928
911
+ }
912
+ },
913
+ "outputs": [],
914
+ "source": [
915
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
916
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
917
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
918
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
919
+ "\n",
920
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
921
+ " variant=variant,\n",
922
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
923
+ ]
924
+ }
925
+ ],
926
+ "metadata": {
927
+ "datalore": {
928
+ "base_environment": "default",
929
+ "computation_mode": "JUPYTER",
930
+ "package_manager": "pip",
931
+ "packages": [
932
+ {
933
+ "name": "datasets",
934
+ "source": "PIP",
935
+ "version": "2.16.1"
936
+ },
937
+ {
938
+ "name": "torch",
939
+ "source": "PIP",
940
+ "version": "2.1.2"
941
+ },
942
+ {
943
+ "name": "accelerate",
944
+ "source": "PIP",
945
+ "version": "0.26.1"
946
+ }
947
+ ],
948
+ "report_row_ids": [
949
+ "un8W7ez7ZwoGb5Co6nydEV",
950
+ "40nN9Hvgi1clHNV5RAemI5",
951
+ "TgRD90H5NSPpKS41OeXI1w",
952
+ "ZOm5BfUs3h1EGLaUkBGeEB",
953
+ "kOP0CZWNSk6vqE3wkPp7Vc",
954
+ "W4PWcOu2O2pRaZyoE2W80h",
955
+ "RolbOnQLIftk0vy9mIcz5M",
956
+ "8OPhUgbaNJmOdiq5D3a6vK",
957
+ "5Qrt3jSvSrpK6Ne1hS6shL",
958
+ "hTq7nFUrovN5Ao4u6dIYWZ",
959
+ "I8WNZLpJ1DVP2wiCW7YBIB",
960
+ "SawhU3I9BewSE1XBPstpNJ",
961
+ "80EtLEl2FIE4FqbWnUD3nT"
962
+ ],
963
+ "version": 3
964
+ },
965
+ "kernel_info": {
966
+ "name": "python38-azureml-pt-tf"
967
+ },
968
+ "kernelspec": {
969
+ "display_name": "Python 3.8 - Pytorch and Tensorflow",
970
+ "language": "python",
971
+ "name": "python38-azureml-pt-tf"
972
+ },
973
+ "language_info": {
974
+ "codemirror_mode": {
975
+ "name": "ipython",
976
+ "version": 3
977
+ },
978
+ "file_extension": ".py",
979
+ "mimetype": "text/x-python",
980
+ "name": "python",
981
+ "nbconvert_exporter": "python",
982
+ "pygments_lexer": "ipython3",
983
+ "version": "3.8.5"
984
+ },
985
+ "microsoft": {
986
+ "host": {
987
+ "AzureML": {
988
+ "notebookHasBeenCompleted": true
989
+ }
990
+ },
991
+ "ms_spell_check": {
992
+ "ms_spell_check_language": "en"
993
+ }
994
+ },
995
+ "nteract": {
996
+ "version": "nteract-front-end@1.0.0"
997
+ }
998
+ },
999
+ "nbformat": 4,
1000
+ "nbformat_minor": 4
1001
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-4-40-54Z.ipynb ADDED
@@ -0,0 +1,1073 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "gather": {
17
+ "logged": 1706475754655
18
+ },
19
+ "nteract": {
20
+ "transient": {
21
+ "deleting": false
22
+ }
23
+ },
24
+ "tags": []
25
+ },
26
+ "outputs": [
27
+ {
28
+ "name": "stdout",
29
+ "output_type": "stream",
30
+ "text": [
31
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
32
+ "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
33
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
34
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
35
+ "Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
36
+ "Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
37
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
38
+ "Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
39
+ "Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
40
+ "Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
41
+ "Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
42
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
43
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
44
+ "Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
45
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
46
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
47
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
48
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
49
+ "Note: you may need to restart the kernel to use updated packages.\n"
50
+ ]
51
+ }
52
+ ],
53
+ "source": [
54
+ "%pip install accelerate -U"
55
+ ]
56
+ },
57
+ {
58
+ "cell_type": "code",
59
+ "execution_count": 2,
60
+ "metadata": {
61
+ "nteract": {
62
+ "transient": {
63
+ "deleting": false
64
+ }
65
+ }
66
+ },
67
+ "outputs": [
68
+ {
69
+ "name": "stdout",
70
+ "output_type": "stream",
71
+ "text": [
72
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
73
+ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
74
+ "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
75
+ "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
76
+ "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
77
+ "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
78
+ "Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
79
+ "Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
80
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
81
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
82
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
83
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
84
+ "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
85
+ "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
86
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
87
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
88
+ "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
89
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
90
+ "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
91
+ "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
92
+ "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
93
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
94
+ "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
95
+ "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
96
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
97
+ "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
98
+ "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
99
+ "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
100
+ "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
101
+ "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
102
+ "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
103
+ "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
104
+ "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
105
+ "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
106
+ "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
107
+ "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
108
+ "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
109
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
110
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
111
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
112
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
113
+ "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
114
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
115
+ "Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
116
+ "Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
117
+ "Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
118
+ "Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
119
+ "Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
120
+ "Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
121
+ "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
122
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
123
+ "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
124
+ "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
125
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
126
+ "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
127
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
128
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
129
+ "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
130
+ "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
131
+ "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
132
+ "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
133
+ "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
134
+ "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
135
+ "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
136
+ "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
137
+ "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
138
+ "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
139
+ "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
140
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
141
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
142
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
143
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
144
+ "Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
145
+ "Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
146
+ "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
147
+ "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
148
+ "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
149
+ "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
150
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
151
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
152
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
153
+ "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
154
+ "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
155
+ "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
156
+ "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
157
+ "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
158
+ "Note: you may need to restart the kernel to use updated packages.\n"
159
+ ]
160
+ }
161
+ ],
162
+ "source": [
163
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
164
+ ]
165
+ },
166
+ {
167
+ "cell_type": "code",
168
+ "execution_count": 4,
169
+ "metadata": {
170
+ "datalore": {
171
+ "hide_input_from_viewers": false,
172
+ "hide_output_from_viewers": false,
173
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
174
+ "report_properties": {
175
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
176
+ },
177
+ "type": "CODE"
178
+ },
179
+ "gather": {
180
+ "logged": 1706486372154
181
+ },
182
+ "tags": []
183
+ },
184
+ "outputs": [
185
+ {
186
+ "name": "stderr",
187
+ "output_type": "stream",
188
+ "text": [
189
+ "[codecarbon INFO @ 04:20:20] [setup] RAM Tracking...\n",
190
+ "[codecarbon INFO @ 04:20:20] [setup] GPU Tracking...\n",
191
+ "[codecarbon INFO @ 04:20:20] Tracking Nvidia GPU via pynvml\n",
192
+ "[codecarbon INFO @ 04:20:20] [setup] CPU Tracking...\n",
193
+ "[codecarbon WARNING @ 04:20:20] No CPU tracking mode found. Falling back on CPU constant mode.\n",
194
+ "[codecarbon WARNING @ 04:20:21] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n",
195
+ "[codecarbon INFO @ 04:20:21] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
196
+ "[codecarbon INFO @ 04:20:21] >>> Tracker's metadata:\n",
197
+ "[codecarbon INFO @ 04:20:21] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n",
198
+ "[codecarbon INFO @ 04:20:21] Python version: 3.8.5\n",
199
+ "[codecarbon INFO @ 04:20:21] CodeCarbon version: 2.3.3\n",
200
+ "[codecarbon INFO @ 04:20:21] Available RAM : 440.883 GB\n",
201
+ "[codecarbon INFO @ 04:20:21] CPU count: 24\n",
202
+ "[codecarbon INFO @ 04:20:21] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
203
+ "[codecarbon INFO @ 04:20:21] GPU count: 4\n",
204
+ "[codecarbon INFO @ 04:20:21] GPU model: 4 x Tesla V100-PCIE-16GB\n",
205
+ "[codecarbon WARNING @ 04:20:21] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
206
+ ]
207
+ }
208
+ ],
209
+ "source": [
210
+ "import pandas as pd\n",
211
+ "import numpy as np\n",
212
+ "import torch\n",
213
+ "import os\n",
214
+ "from typing import List, Union\n",
215
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
216
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
217
+ "import shap\n",
218
+ "import wandb\n",
219
+ "import evaluate\n",
220
+ "from codecarbon import EmissionsTracker\n",
221
+ "import logging\n",
222
+ "\n",
223
+ "wandb.finish()\n",
224
+ "\n",
225
+ "logging.getLogger('codecarbon').propagate = False\n",
226
+ "\n",
227
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
228
+ "tracker = EmissionsTracker()\n",
229
+ "\n",
230
+ "%load_ext watermark"
231
+ ]
232
+ },
233
+ {
234
+ "cell_type": "code",
235
+ "execution_count": 5,
236
+ "metadata": {
237
+ "collapsed": false,
238
+ "gather": {
239
+ "logged": 1706486372304
240
+ },
241
+ "jupyter": {
242
+ "outputs_hidden": false
243
+ }
244
+ },
245
+ "outputs": [],
246
+ "source": [
247
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
248
+ "\n",
249
+ "SEED: int = 42\n",
250
+ "\n",
251
+ "BATCH_SIZE: int = 32\n",
252
+ "EPOCHS: int = 5\n",
253
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
254
+ "\n",
255
+ "# WandB configuration\n",
256
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
257
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
258
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
259
+ ]
260
+ },
261
+ {
262
+ "cell_type": "code",
263
+ "execution_count": 6,
264
+ "metadata": {
265
+ "collapsed": false,
266
+ "jupyter": {
267
+ "outputs_hidden": false
268
+ }
269
+ },
270
+ "outputs": [
271
+ {
272
+ "name": "stdout",
273
+ "output_type": "stream",
274
+ "text": [
275
+ "re : 2.2.1\n",
276
+ "pandas : 2.0.2\n",
277
+ "evaluate: 0.4.1\n",
278
+ "logging : 0.5.1.2\n",
279
+ "torch : 1.12.0\n",
280
+ "shap : 0.44.1\n",
281
+ "wandb : 0.16.2\n",
282
+ "numpy : 1.23.5\n",
283
+ "\n"
284
+ ]
285
+ }
286
+ ],
287
+ "source": [
288
+ "%watermark --iversion"
289
+ ]
290
+ },
291
+ {
292
+ "cell_type": "code",
293
+ "execution_count": 7,
294
+ "metadata": {
295
+ "datalore": {
296
+ "hide_input_from_viewers": true,
297
+ "hide_output_from_viewers": true,
298
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
299
+ "type": "CODE"
300
+ }
301
+ },
302
+ "outputs": [
303
+ {
304
+ "name": "stdout",
305
+ "output_type": "stream",
306
+ "text": [
307
+ "/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
308
+ "Mon Jan 29 04:20:46 2024 \n",
309
+ "+---------------------------------------------------------------------------------------+\n",
310
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
311
+ "|-----------------------------------------+----------------------+----------------------+\n",
312
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
313
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
314
+ "| | | MIG M. |\n",
315
+ "|=========================================+======================+======================|\n",
316
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
317
+ "| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
318
+ "| | | N/A |\n",
319
+ "+-----------------------------------------+----------------------+----------------------+\n",
320
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
321
+ "| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
322
+ "| | | N/A |\n",
323
+ "+-----------------------------------------+----------------------+----------------------+\n",
324
+ "| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
325
+ "| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
326
+ "| | | N/A |\n",
327
+ "+-----------------------------------------+----------------------+----------------------+\n",
328
+ "| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
329
+ "| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
330
+ "| | | N/A |\n",
331
+ "+-----------------------------------------+----------------------+----------------------+\n",
332
+ " \n",
333
+ "+---------------------------------------------------------------------------------------+\n",
334
+ "| Processes: |\n",
335
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
336
+ "| ID ID Usage |\n",
337
+ "|=======================================================================================|\n",
338
+ "| No running processes found |\n",
339
+ "+---------------------------------------------------------------------------------------+\n"
340
+ ]
341
+ }
342
+ ],
343
+ "source": [
344
+ "!nvidia-smi"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "markdown",
349
+ "metadata": {
350
+ "datalore": {
351
+ "hide_input_from_viewers": false,
352
+ "hide_output_from_viewers": false,
353
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
354
+ "report_properties": {
355
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
356
+ },
357
+ "type": "MD"
358
+ }
359
+ },
360
+ "source": [
361
+ "## Loading the data set"
362
+ ]
363
+ },
364
+ {
365
+ "cell_type": "code",
366
+ "execution_count": 8,
367
+ "metadata": {
368
+ "collapsed": false,
369
+ "gather": {
370
+ "logged": 1706486373931
371
+ },
372
+ "jupyter": {
373
+ "outputs_hidden": false
374
+ }
375
+ },
376
+ "outputs": [],
377
+ "source": [
378
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
379
+ ]
380
+ },
381
+ {
382
+ "cell_type": "code",
383
+ "execution_count": 9,
384
+ "metadata": {
385
+ "collapsed": false,
386
+ "gather": {
387
+ "logged": 1706486374218
388
+ },
389
+ "jupyter": {
390
+ "outputs_hidden": false,
391
+ "source_hidden": false
392
+ },
393
+ "nteract": {
394
+ "transient": {
395
+ "deleting": false
396
+ }
397
+ }
398
+ },
399
+ "outputs": [
400
+ {
401
+ "data": {
402
+ "text/plain": [
403
+ "DatasetDict({\n",
404
+ " train: Dataset({\n",
405
+ " features: ['id', 'text', 'label'],\n",
406
+ " num_rows: 1270444\n",
407
+ " })\n",
408
+ " test: Dataset({\n",
409
+ " features: ['id', 'text', 'label'],\n",
410
+ " num_rows: 272238\n",
411
+ " })\n",
412
+ " val: Dataset({\n",
413
+ " features: ['id', 'text', 'label'],\n",
414
+ " num_rows: 272238\n",
415
+ " })\n",
416
+ "})"
417
+ ]
418
+ },
419
+ "execution_count": 9,
420
+ "metadata": {},
421
+ "output_type": "execute_result"
422
+ }
423
+ ],
424
+ "source": [
425
+ "dataset"
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": 10,
431
+ "metadata": {
432
+ "gather": {
433
+ "logged": 1706486374480
434
+ }
435
+ },
436
+ "outputs": [],
437
+ "source": [
438
+ "SUBSAMPLING = 1.0\n",
439
+ "\n",
440
+ "if SUBSAMPLING < 1:\n",
441
+ " _ = DatasetDict()\n",
442
+ " for each in dataset.keys():\n",
443
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
444
+ "\n",
445
+ " dataset = _"
446
+ ]
447
+ },
448
+ {
449
+ "cell_type": "markdown",
450
+ "metadata": {},
451
+ "source": [
452
+ "## Tokenisation and encoding"
453
+ ]
454
+ },
455
+ {
456
+ "cell_type": "code",
457
+ "execution_count": 11,
458
+ "metadata": {
459
+ "gather": {
460
+ "logged": 1706486375030
461
+ }
462
+ },
463
+ "outputs": [],
464
+ "source": [
465
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
466
+ " return ds_enc"
467
+ ]
468
+ },
469
+ {
470
+ "cell_type": "markdown",
471
+ "metadata": {},
472
+ "source": [
473
+ "## Evaluation metrics"
474
+ ]
475
+ },
476
+ {
477
+ "cell_type": "code",
478
+ "execution_count": 12,
479
+ "metadata": {
480
+ "gather": {
481
+ "logged": 1706486375197
482
+ }
483
+ },
484
+ "outputs": [],
485
+ "source": [
486
+ "accuracy = evaluate.load(\"accuracy\")\n",
487
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
488
+ "f1 = evaluate.load(\"f1\")"
489
+ ]
490
+ },
491
+ {
492
+ "cell_type": "code",
493
+ "execution_count": 13,
494
+ "metadata": {
495
+ "gather": {
496
+ "logged": 1706486375361
497
+ }
498
+ },
499
+ "outputs": [],
500
+ "source": [
501
+ "def compute_metrics(eval_pred):\n",
502
+ " predictions, labels = eval_pred\n",
503
+ " predictions = np.argmax(predictions, axis=1)\n",
504
+ " return {\n",
505
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
506
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
507
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
508
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
509
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
510
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
511
+ " }"
512
+ ]
513
+ },
514
+ {
515
+ "cell_type": "markdown",
516
+ "metadata": {},
517
+ "source": [
518
+ "## Training"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "markdown",
523
+ "metadata": {},
524
+ "source": [
525
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
526
+ ]
527
+ },
528
+ {
529
+ "cell_type": "code",
530
+ "execution_count": 14,
531
+ "metadata": {
532
+ "gather": {
533
+ "logged": 1706486375569
534
+ }
535
+ },
536
+ "outputs": [],
537
+ "source": [
538
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
539
+ ]
540
+ },
541
+ {
542
+ "cell_type": "code",
543
+ "execution_count": 15,
544
+ "metadata": {
545
+ "gather": {
546
+ "logged": 1706486433708
547
+ }
548
+ },
549
+ "outputs": [
550
+ {
551
+ "name": "stderr",
552
+ "output_type": "stream",
553
+ "text": [
554
+ "Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1270444/1270444 [08:09<00:00, 2595.90 examples/s]\n",
555
+ "Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272238/272238 [01:45<00:00, 2585.25 examples/s]\n",
556
+ "Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████���█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272238/272238 [01:44<00:00, 2605.66 examples/s]\n"
557
+ ]
558
+ }
559
+ ],
560
+ "source": [
561
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
562
+ "\n",
563
+ "cols = dataset[\"train\"].column_names\n",
564
+ "cols.remove(\"label\")\n",
565
+ "ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n"
566
+ ]
567
+ },
568
+ {
569
+ "cell_type": "code",
570
+ "execution_count": 16,
571
+ "metadata": {},
572
+ "outputs": [
573
+ {
574
+ "name": "stderr",
575
+ "output_type": "stream",
576
+ "text": [
577
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
578
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
579
+ ]
580
+ }
581
+ ],
582
+ "source": [
583
+ "\n",
584
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
585
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
586
+ " id2label=label_map, \n",
587
+ " label2id={v:k for k,v in label_map.items()})\n",
588
+ "\n",
589
+ "args = TrainingArguments(\n",
590
+ " output_dir=\"vaers\",\n",
591
+ " evaluation_strategy=\"epoch\",\n",
592
+ " save_strategy=\"epoch\",\n",
593
+ " learning_rate=2e-5,\n",
594
+ " per_device_train_batch_size=BATCH_SIZE,\n",
595
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
596
+ " num_train_epochs=EPOCHS,\n",
597
+ " weight_decay=.01,\n",
598
+ " logging_steps=1,\n",
599
+ " load_best_model_at_end=True,\n",
600
+ " run_name=f\"daedra-training\",\n",
601
+ " report_to=[\"wandb\"])\n",
602
+ "\n",
603
+ "trainer = Trainer(\n",
604
+ " model=model,\n",
605
+ " args=args,\n",
606
+ " train_dataset=ds_enc[\"train\"],\n",
607
+ " eval_dataset=ds_enc[\"test\"],\n",
608
+ " tokenizer=tokenizer,\n",
609
+ " compute_metrics=compute_metrics)"
610
+ ]
611
+ },
612
+ {
613
+ "cell_type": "code",
614
+ "execution_count": 17,
615
+ "metadata": {
616
+ "gather": {
617
+ "logged": 1706486444806
618
+ }
619
+ },
620
+ "outputs": [
621
+ {
622
+ "name": "stderr",
623
+ "output_type": "stream",
624
+ "text": [
625
+ "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
626
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
627
+ ]
628
+ },
629
+ {
630
+ "data": {
631
+ "text/html": [
632
+ "Tracking run with wandb version 0.16.2"
633
+ ],
634
+ "text/plain": [
635
+ "<IPython.core.display.HTML object>"
636
+ ]
637
+ },
638
+ "metadata": {},
639
+ "output_type": "display_data"
640
+ },
641
+ {
642
+ "data": {
643
+ "text/html": [
644
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_043232-tl59png2</code>"
645
+ ],
646
+ "text/plain": [
647
+ "<IPython.core.display.HTML object>"
648
+ ]
649
+ },
650
+ "metadata": {},
651
+ "output_type": "display_data"
652
+ },
653
+ {
654
+ "data": {
655
+ "text/html": [
656
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
657
+ ],
658
+ "text/plain": [
659
+ "<IPython.core.display.HTML object>"
660
+ ]
661
+ },
662
+ "metadata": {},
663
+ "output_type": "display_data"
664
+ },
665
+ {
666
+ "data": {
667
+ "text/html": [
668
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
669
+ ],
670
+ "text/plain": [
671
+ "<IPython.core.display.HTML object>"
672
+ ]
673
+ },
674
+ "metadata": {},
675
+ "output_type": "display_data"
676
+ },
677
+ {
678
+ "data": {
679
+ "text/html": [
680
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2</a>"
681
+ ],
682
+ "text/plain": [
683
+ "<IPython.core.display.HTML object>"
684
+ ]
685
+ },
686
+ "metadata": {},
687
+ "output_type": "display_data"
688
+ },
689
+ {
690
+ "data": {
691
+ "text/html": [
692
+ "Finishing last run (ID:tl59png2) before initializing another..."
693
+ ],
694
+ "text/plain": [
695
+ "<IPython.core.display.HTML object>"
696
+ ]
697
+ },
698
+ "metadata": {},
699
+ "output_type": "display_data"
700
+ },
701
+ {
702
+ "data": {
703
+ "text/html": [
704
+ " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v0' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v0</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
705
+ ],
706
+ "text/plain": [
707
+ "<IPython.core.display.HTML object>"
708
+ ]
709
+ },
710
+ "metadata": {},
711
+ "output_type": "display_data"
712
+ },
713
+ {
714
+ "data": {
715
+ "text/html": [
716
+ "Find logs at: <code>./wandb/run-20240129_043232-tl59png2/logs</code>"
717
+ ],
718
+ "text/plain": [
719
+ "<IPython.core.display.HTML object>"
720
+ ]
721
+ },
722
+ "metadata": {},
723
+ "output_type": "display_data"
724
+ },
725
+ {
726
+ "data": {
727
+ "text/html": [
728
+ "Successfully finished last run (ID:tl59png2). Initializing new run:<br/>"
729
+ ],
730
+ "text/plain": [
731
+ "<IPython.core.display.HTML object>"
732
+ ]
733
+ },
734
+ "metadata": {},
735
+ "output_type": "display_data"
736
+ },
737
+ {
738
+ "data": {
739
+ "text/html": [
740
+ "Tracking run with wandb version 0.16.2"
741
+ ],
742
+ "text/plain": [
743
+ "<IPython.core.display.HTML object>"
744
+ ]
745
+ },
746
+ "metadata": {},
747
+ "output_type": "display_data"
748
+ },
749
+ {
750
+ "data": {
751
+ "text/html": [
752
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_043243-x8j2xw0x</code>"
753
+ ],
754
+ "text/plain": [
755
+ "<IPython.core.display.HTML object>"
756
+ ]
757
+ },
758
+ "metadata": {},
759
+ "output_type": "display_data"
760
+ },
761
+ {
762
+ "data": {
763
+ "text/html": [
764
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
765
+ ],
766
+ "text/plain": [
767
+ "<IPython.core.display.HTML object>"
768
+ ]
769
+ },
770
+ "metadata": {},
771
+ "output_type": "display_data"
772
+ },
773
+ {
774
+ "data": {
775
+ "text/html": [
776
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
777
+ ],
778
+ "text/plain": [
779
+ "<IPython.core.display.HTML object>"
780
+ ]
781
+ },
782
+ "metadata": {},
783
+ "output_type": "display_data"
784
+ },
785
+ {
786
+ "data": {
787
+ "text/html": [
788
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x</a>"
789
+ ],
790
+ "text/plain": [
791
+ "<IPython.core.display.HTML object>"
792
+ ]
793
+ },
794
+ "metadata": {},
795
+ "output_type": "display_data"
796
+ },
797
+ {
798
+ "data": {
799
+ "text/html": [
800
+ "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
801
+ ],
802
+ "text/plain": [
803
+ "<wandb.sdk.wandb_run.Run at 0x7ffa3d0e9bb0>"
804
+ ]
805
+ },
806
+ "execution_count": 17,
807
+ "metadata": {},
808
+ "output_type": "execute_result"
809
+ }
810
+ ],
811
+ "source": [
812
+ "if SUBSAMPLING != 1.0:\n",
813
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
814
+ "else:\n",
815
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
816
+ "\n",
817
+ "wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
818
+ "wandb_tag.append(f\"base:{model_ckpt}\")\n",
819
+ " \n",
820
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
821
+ ]
822
+ },
823
+ {
824
+ "cell_type": "code",
825
+ "execution_count": 18,
826
+ "metadata": {
827
+ "gather": {
828
+ "logged": 1706486541798
829
+ }
830
+ },
831
+ "outputs": [
832
+ {
833
+ "name": "stderr",
834
+ "output_type": "stream",
835
+ "text": [
836
+ "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
837
+ ]
838
+ },
839
+ {
840
+ "data": {
841
+ "text/html": [
842
+ "\n",
843
+ " <div>\n",
844
+ " \n",
845
+ " <progress value='394' max='49630' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
846
+ " [ 394/49630 04:13 < 8:50:52, 1.55 it/s, Epoch 0.04/5]\n",
847
+ " </div>\n",
848
+ " <table border=\"1\" class=\"dataframe\">\n",
849
+ " <thead>\n",
850
+ " <tr style=\"text-align: left;\">\n",
851
+ " <th>Epoch</th>\n",
852
+ " <th>Training Loss</th>\n",
853
+ " <th>Validation Loss</th>\n",
854
+ " </tr>\n",
855
+ " </thead>\n",
856
+ " <tbody>\n",
857
+ " </tbody>\n",
858
+ "</table><p>"
859
+ ],
860
+ "text/plain": [
861
+ "<IPython.core.display.HTML object>"
862
+ ]
863
+ },
864
+ "metadata": {},
865
+ "output_type": "display_data"
866
+ },
867
+ {
868
+ "name": "stderr",
869
+ "output_type": "stream",
870
+ "text": [
871
+ "[codecarbon INFO @ 04:33:12] Energy consumed for RAM : 0.000689 kWh. RAM Power : 165.33123922348022 W\n",
872
+ "[codecarbon INFO @ 04:33:12] Energy consumed for all GPUs : 0.001450 kWh. Total GPU Power : 347.66451200921796 W\n",
873
+ "[codecarbon INFO @ 04:33:12] Energy consumed for all CPUs : 0.000177 kWh. Total CPU Power : 42.5 W\n",
874
+ "[codecarbon INFO @ 04:33:12] 0.002317 kWh of electricity used since the beginning.\n",
875
+ "[codecarbon INFO @ 04:33:27] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n",
876
+ "[codecarbon INFO @ 04:33:27] Energy consumed for all GPUs : 0.004012 kWh. Total GPU Power : 615.4556826768763 W\n",
877
+ "[codecarbon INFO @ 04:33:27] Energy consumed for all CPUs : 0.000355 kWh. Total CPU Power : 42.5 W\n",
878
+ "[codecarbon INFO @ 04:33:27] 0.005745 kWh of electricity used since the beginning.\n",
879
+ "[codecarbon INFO @ 04:33:42] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n",
880
+ "[codecarbon INFO @ 04:33:42] Energy consumed for all GPUs : 0.006596 kWh. Total GPU Power : 620.9110211178034 W\n",
881
+ "[codecarbon INFO @ 04:33:42] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n",
882
+ "[codecarbon INFO @ 04:33:42] 0.009194 kWh of electricity used since the beginning.\n",
883
+ "[codecarbon INFO @ 04:33:57] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n",
884
+ "[codecarbon INFO @ 04:33:57] Energy consumed for all GPUs : 0.009183 kWh. Total GPU Power : 621.1270289526989 W\n",
885
+ "[codecarbon INFO @ 04:33:57] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n",
886
+ "[codecarbon INFO @ 04:33:57] 0.012645 kWh of electricity used since the beginning.\n",
887
+ "[codecarbon INFO @ 04:34:12] Energy consumed for RAM : 0.003442 kWh. RAM Power : 165.33123922348022 W\n",
888
+ "[codecarbon INFO @ 04:34:12] Energy consumed for all GPUs : 0.011798 kWh. Total GPU Power : 628.3875606622404 W\n",
889
+ "[codecarbon INFO @ 04:34:12] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n",
890
+ "[codecarbon INFO @ 04:34:12] 0.016125 kWh of electricity used since the beginning.\n",
891
+ "[codecarbon INFO @ 04:34:27] Energy consumed for RAM : 0.004130 kWh. RAM Power : 165.33123922348022 W\n",
892
+ "[codecarbon INFO @ 04:34:27] Energy consumed for all GPUs : 0.014431 kWh. Total GPU Power : 632.4054645127197 W\n",
893
+ "[codecarbon INFO @ 04:34:27] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n",
894
+ "[codecarbon INFO @ 04:34:27] 0.019623 kWh of electricity used since the beginning.\n",
895
+ "[codecarbon INFO @ 04:34:42] Energy consumed for RAM : 0.004818 kWh. RAM Power : 165.33123922348022 W\n",
896
+ "[codecarbon INFO @ 04:34:42] Energy consumed for all GPUs : 0.017064 kWh. Total GPU Power : 632.6571124342939 W\n",
897
+ "[codecarbon INFO @ 04:34:42] Energy consumed for all CPUs : 0.001240 kWh. Total CPU Power : 42.5 W\n",
898
+ "[codecarbon INFO @ 04:34:42] 0.023122 kWh of electricity used since the beginning.\n",
899
+ "[codecarbon INFO @ 04:34:57] Energy consumed for RAM : 0.005506 kWh. RAM Power : 165.33123922348022 W\n",
900
+ "[codecarbon INFO @ 04:34:57] Energy consumed for all GPUs : 0.019707 kWh. Total GPU Power : 634.7921879339333 W\n",
901
+ "[codecarbon INFO @ 04:34:57] Energy consumed for all CPUs : 0.001417 kWh. Total CPU Power : 42.5 W\n",
902
+ "[codecarbon INFO @ 04:34:57] 0.026631 kWh of electricity used since the beginning.\n",
903
+ "[codecarbon INFO @ 04:35:12] Energy consumed for RAM : 0.006194 kWh. RAM Power : 165.33123922348022 W\n",
904
+ "[codecarbon INFO @ 04:35:12] Energy consumed for all GPUs : 0.022334 kWh. Total GPU Power : 630.3609394863598 W\n",
905
+ "[codecarbon INFO @ 04:35:12] Energy consumed for all CPUs : 0.001594 kWh. Total CPU Power : 42.5 W\n",
906
+ "[codecarbon INFO @ 04:35:12] 0.030123 kWh of electricity used since the beginning.\n",
907
+ "[codecarbon INFO @ 04:35:27] Energy consumed for RAM : 0.006882 kWh. RAM Power : 165.33123922348022 W\n",
908
+ "[codecarbon INFO @ 04:35:27] Energy consumed for all GPUs : 0.024956 kWh. Total GPU Power : 630.704729336156 W\n",
909
+ "[codecarbon INFO @ 04:35:27] Energy consumed for all CPUs : 0.001771 kWh. Total CPU Power : 42.5 W\n",
910
+ "[codecarbon INFO @ 04:35:27] 0.033609 kWh of electricity used since the beginning.\n",
911
+ "[codecarbon INFO @ 04:35:42] Energy consumed for RAM : 0.007570 kWh. RAM Power : 165.33123922348022 W\n",
912
+ "[codecarbon INFO @ 04:35:42] Energy consumed for all GPUs : 0.027604 kWh. Total GPU Power : 636.1545465788125 W\n",
913
+ "[codecarbon INFO @ 04:35:42] Energy consumed for all CPUs : 0.001948 kWh. Total CPU Power : 42.5 W\n",
914
+ "[codecarbon INFO @ 04:35:42] 0.037121 kWh of electricity used since the beginning.\n",
915
+ "[codecarbon INFO @ 04:35:57] Energy consumed for RAM : 0.008258 kWh. RAM Power : 165.33123922348022 W\n",
916
+ "[codecarbon INFO @ 04:35:57] Energy consumed for all GPUs : 0.030255 kWh. Total GPU Power : 636.9769106141198 W\n",
917
+ "[codecarbon INFO @ 04:35:57] Energy consumed for all CPUs : 0.002125 kWh. Total CPU Power : 42.5 W\n",
918
+ "[codecarbon INFO @ 04:35:57] 0.040638 kWh of electricity used since the beginning.\n",
919
+ "[codecarbon INFO @ 04:36:12] Energy consumed for RAM : 0.008946 kWh. RAM Power : 165.33123922348022 W\n",
920
+ "[codecarbon INFO @ 04:36:12] Energy consumed for all GPUs : 0.032913 kWh. Total GPU Power : 638.3412890613937 W\n",
921
+ "[codecarbon INFO @ 04:36:12] Energy consumed for all CPUs : 0.002302 kWh. Total CPU Power : 42.5 W\n",
922
+ "[codecarbon INFO @ 04:36:12] 0.044161 kWh of electricity used since the beginning.\n",
923
+ "[codecarbon INFO @ 04:36:27] Energy consumed for RAM : 0.009634 kWh. RAM Power : 165.33123922348022 W\n",
924
+ "[codecarbon INFO @ 04:36:27] Energy consumed for all GPUs : 0.035515 kWh. Total GPU Power : 625.0502398771333 W\n",
925
+ "[codecarbon INFO @ 04:36:27] Energy consumed for all CPUs : 0.002479 kWh. Total CPU Power : 42.5 W\n",
926
+ "[codecarbon INFO @ 04:36:27] 0.047628 kWh of electricity used since the beginning.\n",
927
+ "[codecarbon INFO @ 04:36:42] Energy consumed for RAM : 0.010322 kWh. RAM Power : 165.33123922348022 W\n",
928
+ "[codecarbon INFO @ 04:36:42] Energy consumed for all GPUs : 0.038183 kWh. Total GPU Power : 641.00719087638 W\n",
929
+ "[codecarbon INFO @ 04:36:42] Energy consumed for all CPUs : 0.002656 kWh. Total CPU Power : 42.5 W\n",
930
+ "[codecarbon INFO @ 04:36:42] 0.051162 kWh of electricity used since the beginning.\n",
931
+ "[codecarbon INFO @ 04:36:57] Energy consumed for RAM : 0.011010 kWh. RAM Power : 165.33123922348022 W\n",
932
+ "[codecarbon INFO @ 04:36:57] Energy consumed for all GPUs : 0.040821 kWh. Total GPU Power : 633.4817689949092 W\n",
933
+ "[codecarbon INFO @ 04:36:57] Energy consumed for all CPUs : 0.002834 kWh. Total CPU Power : 42.5 W\n",
934
+ "[codecarbon INFO @ 04:36:57] 0.054665 kWh of electricity used since the beginning.\n",
935
+ "[codecarbon INFO @ 04:37:12] Energy consumed for RAM : 0.011698 kWh. RAM Power : 165.33123922348022 W\n",
936
+ "[codecarbon INFO @ 04:37:12] Energy consumed for all GPUs : 0.043484 kWh. Total GPU Power : 639.8452880027475 W\n",
937
+ "[codecarbon INFO @ 04:37:12] Energy consumed for all CPUs : 0.003011 kWh. Total CPU Power : 42.5 W\n",
938
+ "[codecarbon INFO @ 04:37:12] 0.058193 kWh of electricity used since the beginning.\n"
939
+ ]
940
+ }
941
+ ],
942
+ "source": [
+ "tracker.start()\n",
+ "trainer.train()\n",
+ "tracker.stop()\n"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "gather": {
+ "logged": 1706486541918
+ }
+ },
+ "outputs": [],
+ "source": [
+ "wandb.finish()"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {
+ "gather": {
+ "logged": 1706486541928
+ }
+ },
+ "outputs": [],
+ "source": [
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+ "\n",
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+ " variant=variant,\n",
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+ "tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+ "tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+ "sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+ "\n",
+ "model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+ " variant=variant,\n",
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
+ ]
+ }
+ ],
+ "metadata": {
+ "datalore": {
+ "base_environment": "default",
+ "computation_mode": "JUPYTER",
+ "package_manager": "pip",
+ "packages": [
+ {
+ "name": "datasets",
+ "source": "PIP",
+ "version": "2.16.1"
+ },
+ {
+ "name": "torch",
+ "source": "PIP",
+ "version": "2.1.2"
+ },
+ {
+ "name": "accelerate",
+ "source": "PIP",
+ "version": "0.26.1"
+ }
+ ],
+ "report_row_ids": [
+ "un8W7ez7ZwoGb5Co6nydEV",
+ "40nN9Hvgi1clHNV5RAemI5",
+ "TgRD90H5NSPpKS41OeXI1w",
+ "ZOm5BfUs3h1EGLaUkBGeEB",
+ "kOP0CZWNSk6vqE3wkPp7Vc",
+ "W4PWcOu2O2pRaZyoE2W80h",
+ "RolbOnQLIftk0vy9mIcz5M",
+ "8OPhUgbaNJmOdiq5D3a6vK",
+ "5Qrt3jSvSrpK6Ne1hS6shL",
+ "hTq7nFUrovN5Ao4u6dIYWZ",
+ "I8WNZLpJ1DVP2wiCW7YBIB",
+ "SawhU3I9BewSE1XBPstpNJ",
+ "80EtLEl2FIE4FqbWnUD3nT"
+ ],
+ "version": 3
+ },
+ "kernel_info": {
+ "name": "python38-azureml-pt-tf"
+ },
+ "kernelspec": {
+ "display_name": "azureml_py38_PT_TF",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "file_extension": ".py",
+ "mimetype": "text/x-python",
+ "name": "python",
+ "nbconvert_exporter": "python",
+ "pygments_lexer": "ipython3",
+ "version": "3.8.5"
+ },
+ "microsoft": {
+ "host": {
+ "AzureML": {
+ "notebookHasBeenCompleted": true
+ }
+ },
+ "ms_spell_check": {
+ "ms_spell_check_language": "en"
+ }
+ },
+ "nteract": {
+ "version": "nteract-front-end@1.0.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-30-21-44-8Z.ipynb ADDED
@@ -0,0 +1,671 @@
+ {
+ "cells": [
+ {
+ "cell_type": "markdown",
+ "source": [
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
+ "\n",
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%pip install accelerate -U"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
+ }
+ ],
+ "execution_count": 1,
+ "metadata": {
+ "gather": {
+ "logged": 1706475754655
+ },
+ "nteract": {
+ "transient": {
+ "deleting": false
+ }
+ },
+ "tags": []
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
+ }
+ ],
+ "execution_count": 2,
+ "metadata": {
+ "nteract": {
+ "transient": {
+ "deleting": false
+ }
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "import pandas as pd\n",
+ "import numpy as np\n",
+ "import torch\n",
+ "import os\n",
+ "from typing import List, Union\n",
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
+ "import shap\n",
+ "import wandb\n",
+ "import evaluate\n",
+ "import logging\n",
+ "\n",
+ "wandb.finish()\n",
+ "\n",
+ "\n",
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
+ "\n",
+ "%load_ext watermark"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 17:46:15.020290: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 17:46:16.031641: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031779: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
+ }
+ ],
+ "execution_count": 3,
+ "metadata": {
+ "datalore": {
+ "hide_input_from_viewers": false,
+ "hide_output_from_viewers": false,
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
+ "report_properties": {
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
+ },
+ "type": "CODE"
+ },
+ "gather": {
+ "logged": 1706550378660
+ },
+ "tags": []
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+ "\n",
+ "SEED: int = 42\n",
+ "\n",
+ "BATCH_SIZE: int = 32\n",
+ "EPOCHS: int = 5\n",
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
+ "\n",
+ "# WandB configuration\n",
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
+ ],
+ "outputs": [],
+ "execution_count": 4,
+ "metadata": {
+ "collapsed": false,
+ "gather": {
+ "logged": 1706550378812
+ },
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "%watermark --iversion"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "shap : 0.44.1\npandas : 2.0.2\nwandb : 0.16.2\nre : 2.2.1\nevaluate: 0.4.1\ntorch : 1.12.0\nnumpy : 1.23.5\nlogging : 0.5.1.2\n\n"
+ }
+ ],
+ "execution_count": 5,
+ "metadata": {
+ "collapsed": false,
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "!nvidia-smi"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "Mon Jan 29 17:46:18 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
+ }
+ ],
+ "execution_count": 6,
+ "metadata": {
+ "datalore": {
+ "hide_input_from_viewers": true,
+ "hide_output_from_viewers": true,
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
+ "type": "CODE"
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Loading the data set"
+ ],
+ "metadata": {
+ "datalore": {
+ "hide_input_from_viewers": false,
+ "hide_output_from_viewers": false,
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
+ "report_properties": {
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
+ },
+ "type": "MD"
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
+ ],
+ "outputs": [],
+ "execution_count": 7,
+ "metadata": {
+ "collapsed": false,
+ "gather": {
+ "logged": 1706550381141
+ },
+ "jupyter": {
+ "outputs_hidden": false
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "dataset"
+ ],
+ "outputs": [
+ {
+ "output_type": "execute_result",
+ "execution_count": 8,
+ "data": {
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
+ },
+ "metadata": {}
+ }
+ ],
+ "execution_count": 8,
+ "metadata": {
+ "collapsed": false,
+ "gather": {
+ "logged": 1706550381303
+ },
+ "jupyter": {
+ "outputs_hidden": false,
+ "source_hidden": false
+ },
+ "nteract": {
+ "transient": {
+ "deleting": false
+ }
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "SUBSAMPLING = 0.01\n",
+ "\n",
+ "if SUBSAMPLING < 1:\n",
+ " _ = DatasetDict()\n",
+ " for each in dataset.keys():\n",
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
+ "\n",
+ " dataset = _"
+ ],
+ "outputs": [],
+ "execution_count": 9,
+ "metadata": {
+ "gather": {
+ "logged": 1706550381472
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Tokenisation and encoding"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
271
+ " return ds_enc"
272
+ ],
+ "outputs": [],
+ "execution_count": 10,
+ "metadata": {
+ "gather": {
+ "logged": 1706550381637
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Evaluation metrics"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "accuracy = evaluate.load(\"accuracy\")\n",
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
+ "f1 = evaluate.load(\"f1\")"
+ ],
+ "outputs": [],
+ "execution_count": 11,
+ "metadata": {
+ "gather": {
+ "logged": 1706550381778
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "def compute_metrics(eval_pred):\n",
+ " predictions, labels = eval_pred\n",
+ " predictions = np.argmax(predictions, axis=1)\n",
+ " return {\n",
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
+ " }"
+ ],
+ "outputs": [],
+ "execution_count": 12,
+ "metadata": {
+ "gather": {
+ "logged": 1706550381891
+ }
+ }
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "## Training"
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "markdown",
+ "source": [
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
+ ],
+ "metadata": {}
+ },
+ {
+ "cell_type": "code",
+ "source": [
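+ "# Map integer class ids to their string labels from the ClassLabel feature\n",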
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
344
+ ],
345
+ "outputs": [],
346
+ "execution_count": 13,
347
+ "metadata": {
348
+ "gather": {
349
+ "logged": 1706550382032
350
+ }
351
+ }
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "source": [
356
+ "def train_from_model(model_ckpt: str, push: bool = False):\n",
357
+ " print(f\"Initialising training based on {model_ckpt}...\")\n",
358
+ "\n",
359
+ " print(\"Tokenising...\")\n",
360
+ " tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
361
+ "\n",
362
+ " cols = dataset[\"train\"].column_names\n",
363
+ " cols.remove(\"label\")\n",
364
+ " ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
365
+ "\n",
366
+ " print(\"Loading model...\")\n",
+ " try:\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
+ " id2label=label_map, \n",
+ " label2id={v:k for k,v in label_map.items()})\n",
+ " except OSError:\n",
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
+ " id2label=label_map, \n",
+ " label2id={v:k for k,v in label_map.items()},\n",
+ " from_tf=True)\n",
+ "\n",
+ "\n",
+ " args = TrainingArguments(\n",
+ " output_dir=\"vaers\",\n",
+ " evaluation_strategy=\"epoch\",\n",
+ " save_strategy=\"epoch\",\n",
+ " learning_rate=2e-5,\n",
+ " per_device_train_batch_size=BATCH_SIZE,\n",
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
+ " num_train_epochs=EPOCHS,\n",
+ " weight_decay=.01,\n",
+ " logging_steps=1,\n",
+ " load_best_model_at_end=True,\n",
+ " run_name=f\"daedra-training\",\n",
+ " report_to=[\"wandb\"])\n",
+ "\n",
+ " trainer = Trainer(\n",
+ " model=model,\n",
+ " args=args,\n",
+ " train_dataset=ds_enc[\"train\"],\n",
+ " eval_dataset=ds_enc[\"test\"],\n",
+ " tokenizer=tokenizer,\n",
+ " compute_metrics=compute_metrics)\n",
+ " \n",
+ " if SUBSAMPLING != 1.0:\n",
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
+ " else:\n",
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
+ "\n",
+ " wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
+ " wandb_tag.append(f\"base:{model_ckpt}\")\n",
+ " \n",
+ " wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
+ "\n",
+ " print(\"Starting training...\")\n",
+ "\n",
+ " trainer.train()\n",
+ "\n",
+ " print(\"Training finished.\")\n",
+ "\n",
+ " if push:\n",
+ " variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+ " tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+ " tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+ " sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+ "\n",
+ " model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+ " variant=variant,\n",
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
+ ],
+ "outputs": [],
+ "execution_count": 14,
+ "metadata": {
+ "jupyter": {
+ "outputs_hidden": false,
+ "source_hidden": false
+ },
+ "nteract": {
+ "transient": {
+ "deleting": false
+ }
+ },
+ "gather": {
+ "logged": 1706550382160
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "\n",
+ "base_models = [\n",
+ " \"bert-base-uncased\",\n",
+ " \"distilbert-base-uncased\",\n",
+ "]"
+ ],
+ "outputs": [],
+ "execution_count": 15,
+ "metadata": {
+ "gather": {
+ "logged": 1706550382318
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [
+ "BATCH_SIZE=1\n",
+ "\n",
+ "train_from_model(\"biobert/Bio_ClinicalBERT/\")"
+ ],
+ "outputs": [
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "Initialising training based on biobert/Bio_ClinicalBERT/...\nTokenising...\nLoading model...\n"
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": "Map: 100%|██████████| 2722/2722 [00:01<00:00, 2195.12 examples/s]\nAll TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nAll the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "Finishing last run (ID:sg022tqh) before initializing another..."
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": " View run <strong style=\"color:#cdcd00\">daedra_0.01-biobert/Bio_ClinicalBERT/</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "Find logs at: <code>./wandb/run-20240129_174816-sg022tqh/logs</code>"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "Successfully finished last run (ID:sg022tqh). Initializing new run:<br/>"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "Tracking run with wandb version 0.16.2"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_174936-kilkkg1j</code>"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">daedra_0.01-biobert/Bio_ClinicalBERT/</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j</a>"
+ },
+ "metadata": {}
+ },
+ {
+ "output_type": "stream",
+ "name": "stdout",
+ "text": "Starting training...\n"
+ },
+ {
+ "output_type": "stream",
+ "name": "stderr",
+ "text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
+ },
+ {
+ "output_type": "display_data",
+ "data": {
+ "text/plain": "<IPython.core.display.HTML object>",
+ "text/html": "\n <div>\n \n <progress value='1496' max='15880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1496/15880 07:43 < 1:14:19, 3.23 it/s, Epoch 0.47/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
+ },
+ "metadata": {}
+ }
+ ],
+ "execution_count": 21,
+ "metadata": {
+ "gather": {
+ "logged": 1706551053473
+ }
+ }
+ },
+ {
+ "cell_type": "code",
+ "source": [],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {
+ "jupyter": {
+ "source_hidden": false,
+ "outputs_hidden": false
+ },
+ "nteract": {
+ "transient": {
+ "deleting": false
+ }
+ }
+ }
+ }
+ ],
+ "metadata": {
+ "datalore": {
+ "base_environment": "default",
+ "computation_mode": "JUPYTER",
+ "package_manager": "pip",
+ "packages": [
+ {
+ "name": "datasets",
+ "source": "PIP",
+ "version": "2.16.1"
+ },
+ {
+ "name": "torch",
+ "source": "PIP",
+ "version": "2.1.2"
+ },
+ {
+ "name": "accelerate",
+ "source": "PIP",
+ "version": "0.26.1"
+ }
+ ],
+ "report_row_ids": [
+ "un8W7ez7ZwoGb5Co6nydEV",
+ "40nN9Hvgi1clHNV5RAemI5",
+ "TgRD90H5NSPpKS41OeXI1w",
+ "ZOm5BfUs3h1EGLaUkBGeEB",
+ "kOP0CZWNSk6vqE3wkPp7Vc",
+ "W4PWcOu2O2pRaZyoE2W80h",
+ "RolbOnQLIftk0vy9mIcz5M",
+ "8OPhUgbaNJmOdiq5D3a6vK",
+ "5Qrt3jSvSrpK6Ne1hS6shL",
+ "hTq7nFUrovN5Ao4u6dIYWZ",
+ "I8WNZLpJ1DVP2wiCW7YBIB",
+ "SawhU3I9BewSE1XBPstpNJ",
+ "80EtLEl2FIE4FqbWnUD3nT"
+ ],
+ "version": 3
+ },
+ "kernel_info": {
+ "name": "python38-azureml-pt-tf"
+ },
+ "kernelspec": {
+ "display_name": "azureml_py38_PT_TF",
+ "language": "python",
+ "name": "python3"
+ },
+ "language_info": {
+ "name": "python",
+ "version": "3.8.5",
+ "mimetype": "text/x-python",
+ "codemirror_mode": {
+ "name": "ipython",
+ "version": 3
+ },
+ "pygments_lexer": "ipython3",
+ "nbconvert_exporter": "python",
+ "file_extension": ".py"
+ },
+ "microsoft": {
+ "host": {
+ "AzureML": {
+ "notebookHasBeenCompleted": true
+ }
+ },
+ "ms_spell_check": {
+ "ms_spell_check_language": "en"
+ }
+ },
+ "nteract": {
+ "version": "nteract-front-end@1.0.0"
+ }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 4
+ }
notebooks/.ipynb_aml_checkpoints/microsample_model_comparison-checkpoint2024-0-31-14-6-22Z.ipynb ADDED
File without changes
notebooks/DAEDRA-Copy1.ipynb ADDED
@@ -0,0 +1,1634 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "metadata": {},
6
+ "source": [
7
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
8
+ "\n",
9
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
10
+ ]
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "execution_count": 1,
15
+ "metadata": {
16
+ "nteract": {
17
+ "transient": {
18
+ "deleting": false
19
+ }
20
+ },
21
+ "tags": []
22
+ },
23
+ "outputs": [
24
+ {
25
+ "name": "stdout",
26
+ "output_type": "stream",
27
+ "text": [
28
+ "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
29
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
30
+ "Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
31
+ "Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
32
+ "Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
33
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
34
+ "Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
35
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
36
+ "Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
37
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
38
+ "Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
39
+ "Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
40
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
41
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
42
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
43
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
44
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
45
+ "Note: you may need to restart the kernel to use updated packages.\n"
46
+ ]
47
+ }
48
+ ],
49
+ "source": [
50
+ "# %pip install accelerate -U"
51
+ ]
52
+ },
53
+ {
54
+ "cell_type": "code",
55
+ "execution_count": 2,
56
+ "metadata": {
57
+ "collapsed": true,
58
+ "jupyter": {
59
+ "outputs_hidden": true,
60
+ "source_hidden": false
61
+ },
62
+ "nteract": {
63
+ "transient": {
64
+ "deleting": false
65
+ }
66
+ }
67
+ },
68
+ "outputs": [
69
+ {
70
+ "name": "stdout",
71
+ "output_type": "stream",
72
+ "text": [
73
+ "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
74
+ "Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
75
+ "Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
76
+ "Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
77
+ "Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
78
+ "Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
79
+ "Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
80
+ "Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
81
+ "Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
82
+ "Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
83
+ "Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
84
+ "Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
85
+ "Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
86
+ "Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
87
+ "Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
88
+ "Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
89
+ "Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
90
+ "Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
91
+ "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
92
+ "Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
93
+ "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
94
+ "Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
95
+ "Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
96
+ "Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
97
+ "Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
98
+ "Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
99
+ "Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
100
+ "Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
101
+ "Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
102
+ "Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
103
+ "Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
104
+ "Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
105
+ "Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
106
+ "Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
107
+ "Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
108
+ "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
109
+ "Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
110
+ "Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
111
+ "Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
112
+ "Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
113
+ "Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
114
+ "Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
115
+ "Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
116
+ "Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
117
+ "Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
118
+ "Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
119
+ "Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
120
+ "Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
121
+ "Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
122
+ "Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
123
+ "Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
124
+ "Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
125
+ "Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
126
+ "Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
127
+ "Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
128
+ "Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
129
+ "Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
130
+ "Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
131
+ "Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
132
+ "Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
133
+ "Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
134
+ "Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
135
+ "Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
136
+ "Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
137
+ "Requirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n",
138
+ "Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
139
+ "Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
140
+ "Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
141
+ "Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
142
+ "Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
143
+ "Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
144
+ "Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
145
+ "Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
146
+ "Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
147
+ "Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
148
+ "Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
149
+ "Note: you may need to restart the kernel to use updated packages.\n"
150
+ ]
151
+ }
152
+ ],
153
+ "source": [
154
+ "# %pip install transformers datasets shap watermark wandb"
155
+ ]
156
+ },
157
+ {
158
+ "cell_type": "code",
159
+ "execution_count": 66,
160
+ "metadata": {
161
+ "datalore": {
162
+ "hide_input_from_viewers": false,
163
+ "hide_output_from_viewers": false,
164
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
165
+ "report_properties": {
166
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
167
+ },
168
+ "type": "CODE"
169
+ },
170
+ "gather": {
171
+ "logged": 1706449625034
172
+ },
173
+ "tags": []
174
+ },
175
+ "outputs": [
176
+ {
177
+ "name": "stdout",
178
+ "output_type": "stream",
179
+ "text": [
180
+ "The watermark extension is already loaded. To reload it, use:\n",
181
+ " %reload_ext watermark\n"
182
+ ]
183
+ }
184
+ ],
185
+ "source": [
186
+ "import pandas as pd\n",
187
+ "import numpy as np\n",
188
+ "import torch\n",
189
+ "import os\n",
190
+ "from typing import List\n",
191
+ "from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
192
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
193
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
194
+ "from pyarrow import Table\n",
195
+ "import shap\n",
196
+ "import wandb\n",
197
+ "\n",
198
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
199
+ "\n",
200
+ "%load_ext watermark"
201
+ ]
202
+ },
203
+ {
204
+ "cell_type": "code",
205
+ "execution_count": 43,
206
+ "metadata": {
207
+ "collapsed": false,
208
+ "gather": {
209
+ "logged": 1706449721319
210
+ },
211
+ "jupyter": {
212
+ "outputs_hidden": false
213
+ }
214
+ },
215
+ "outputs": [],
216
+ "source": [
217
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
218
+ "\n",
219
+ "SEED: int = 42\n",
220
+ "\n",
221
+ "BATCH_SIZE: int = 32\n",
222
+ "EPOCHS: int = 3\n",
223
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
224
+ "\n",
225
+ "CLASS_NAMES: List[str] = [\"DIED\",\n",
226
+ " \"ER_VISIT\",\n",
227
+ " \"HOSPITAL\",\n",
228
+ " \"OFC_VISIT\",\n",
229
+ " #\"X_STAY\", # pruned\n",
230
+ " #\"DISABLE\", # pruned\n",
231
+ " #\"D_PRESENTED\" # pruned\n",
232
+ " ]\n",
233
+ "\n",
234
+ "\n",
235
+ "\n",
236
+ "\n",
237
+ "# WandB configuration\n",
238
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
239
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
240
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
241
+ ]
242
+ },
243
+ {
244
+ "cell_type": "code",
245
+ "execution_count": 44,
246
+ "metadata": {
247
+ "collapsed": false,
248
+ "jupyter": {
249
+ "outputs_hidden": false
250
+ }
251
+ },
252
+ "outputs": [
253
+ {
254
+ "name": "stdout",
255
+ "output_type": "stream",
256
+ "text": [
257
+ "shap : 0.44.1\n",
258
+ "torch : 1.12.0\n",
259
+ "logging: 0.5.1.2\n",
260
+ "numpy : 1.23.5\n",
261
+ "pandas : 2.0.2\n",
262
+ "re : 2.2.1\n",
263
+ "\n"
264
+ ]
265
+ }
266
+ ],
267
+ "source": [
268
+ "%watermark --iversion"
269
+ ]
270
+ },
271
+ {
272
+ "cell_type": "code",
273
+ "execution_count": 45,
274
+ "metadata": {
275
+ "datalore": {
276
+ "hide_input_from_viewers": true,
277
+ "hide_output_from_viewers": true,
278
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
279
+ "type": "CODE"
280
+ }
281
+ },
282
+ "outputs": [
283
+ {
284
+ "name": "stdout",
285
+ "output_type": "stream",
286
+ "text": [
287
+ "Sun Jan 28 13:54:22 2024 \n",
288
+ "+---------------------------------------------------------------------------------------+\n",
289
+ "| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
290
+ "|-----------------------------------------+----------------------+----------------------+\n",
291
+ "| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
292
+ "| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
293
+ "| | | MIG M. |\n",
294
+ "|=========================================+======================+======================|\n",
295
+ "| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
296
+ "| N/A 30C P0 38W / 250W | 12830MiB / 16384MiB | 0% Default |\n",
297
+ "| | | N/A |\n",
298
+ "+-----------------------------------------+----------------------+----------------------+\n",
299
+ "| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
300
+ "| N/A 30C P0 38W / 250W | 11960MiB / 16384MiB | 0% Default |\n",
301
+ "| | | N/A |\n",
302
+ "+-----------------------------------------+----------------------+----------------------+\n",
303
+ " \n",
304
+ "+---------------------------------------------------------------------------------------+\n",
305
+ "| Processes: |\n",
306
+ "| GPU GI CI PID Type Process name GPU Memory |\n",
307
+ "| ID ID Usage |\n",
308
+ "|=======================================================================================|\n",
309
+ "| 0 N/A N/A 11781 C .../envs/azureml_py38_PT_TF/bin/python 12826MiB |\n",
310
+ "| 1 N/A N/A 11781 C .../envs/azureml_py38_PT_TF/bin/python 11956MiB |\n",
311
+ "+---------------------------------------------------------------------------------------+\n"
312
+ ]
313
+ }
314
+ ],
315
+ "source": [
316
+ "!nvidia-smi"
317
+ ]
318
+ },
319
+ {
320
+ "cell_type": "markdown",
321
+ "metadata": {
322
+ "datalore": {
323
+ "hide_input_from_viewers": false,
324
+ "hide_output_from_viewers": false,
325
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
326
+ "report_properties": {
327
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
328
+ },
329
+ "type": "MD"
330
+ }
331
+ },
332
+ "source": [
333
+ "## Loading the data set"
334
+ ]
335
+ },
336
+ {
337
+ "cell_type": "code",
338
+ "execution_count": 46,
339
+ "metadata": {
340
+ "collapsed": false,
341
+ "gather": {
342
+ "logged": 1706449040507
343
+ },
344
+ "jupyter": {
345
+ "outputs_hidden": false
346
+ }
347
+ },
348
+ "outputs": [],
349
+ "source": [
350
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
351
+ ]
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "execution_count": 47,
356
+ "metadata": {
357
+ "collapsed": false,
358
+ "gather": {
359
+ "logged": 1706449044205
360
+ },
361
+ "jupyter": {
362
+ "outputs_hidden": false,
363
+ "source_hidden": false
364
+ },
365
+ "nteract": {
366
+ "transient": {
367
+ "deleting": false
368
+ }
369
+ }
370
+ },
371
+ "outputs": [
372
+ {
373
+ "data": {
374
+ "text/plain": [
375
+ "DatasetDict({\n",
376
+ " train: Dataset({\n",
377
+ " features: ['id', 'text', 'labels'],\n",
378
+ " num_rows: 1270444\n",
379
+ " })\n",
380
+ " test: Dataset({\n",
381
+ " features: ['id', 'text', 'labels'],\n",
382
+ " num_rows: 272238\n",
383
+ " })\n",
384
+ " val: Dataset({\n",
385
+ " features: ['id', 'text', 'labels'],\n",
386
+ " num_rows: 272238\n",
387
+ " })\n",
388
+ "})"
389
+ ]
390
+ },
391
+ "execution_count": 47,
392
+ "metadata": {},
393
+ "output_type": "execute_result"
394
+ }
395
+ ],
396
+ "source": [
397
+ "dataset"
398
+ ]
399
+ },
400
+ {
401
+ "cell_type": "code",
402
+ "execution_count": 70,
403
+ "metadata": {},
404
+ "outputs": [],
405
+ "source": [
406
+ "SUBSAMPLING: float = 0.1"
407
+ ]
408
+ },
409
+ {
410
+ "cell_type": "code",
411
+ "execution_count": 48,
412
+ "metadata": {
413
+ "collapsed": false,
414
+ "gather": {
415
+ "logged": 1706449378281
416
+ },
417
+ "jupyter": {
418
+ "outputs_hidden": false,
419
+ "source_hidden": false
420
+ },
421
+ "nteract": {
422
+ "transient": {
423
+ "deleting": false
424
+ }
425
+ }
426
+ },
427
+ "outputs": [],
428
+ "source": [
429
+ "def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
430
+ " res = DatasetDict()\n",
431
+ "\n",
432
+ " res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
433
+ " res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
434
+ " res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
435
+ " \n",
436
+ " return res"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": 49,
442
+ "metadata": {
443
+ "collapsed": false,
444
+ "gather": {
445
+ "logged": 1706449384162
446
+ },
447
+ "jupyter": {
448
+ "outputs_hidden": false,
449
+ "source_hidden": false
450
+ },
451
+ "nteract": {
452
+ "transient": {
453
+ "deleting": false
454
+ }
455
+ }
456
+ },
457
+ "outputs": [],
458
+ "source": [
459
+ "dataset = minisample(dataset, SUBSAMPLING)"
460
+ ]
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "execution_count": 50,
465
+ "metadata": {
466
+ "collapsed": false,
467
+ "gather": {
468
+ "logged": 1706449387981
469
+ },
470
+ "jupyter": {
471
+ "outputs_hidden": false,
472
+ "source_hidden": false
473
+ },
474
+ "nteract": {
475
+ "transient": {
476
+ "deleting": false
477
+ }
478
+ }
479
+ },
480
+ "outputs": [
481
+ {
482
+ "data": {
483
+ "text/plain": [
484
+ "DatasetDict({\n",
485
+ " train: Dataset({\n",
486
+ " features: ['id', 'text', 'labels'],\n",
487
+ " num_rows: 127044\n",
488
+ " })\n",
489
+ " test: Dataset({\n",
490
+ " features: ['id', 'text', 'labels'],\n",
491
+ " num_rows: 27224\n",
492
+ " })\n",
493
+ " val: Dataset({\n",
494
+ " features: ['id', 'text', 'labels'],\n",
495
+ " num_rows: 27224\n",
496
+ " })\n",
497
+ "})"
498
+ ]
499
+ },
500
+ "execution_count": 50,
501
+ "metadata": {},
502
+ "output_type": "execute_result"
503
+ }
504
+ ],
505
+ "source": [
506
+ "dataset"
507
+ ]
508
+ },
509
+ {
510
+ "cell_type": "markdown",
511
+ "metadata": {
512
+ "nteract": {
513
+ "transient": {
514
+ "deleting": false
515
+ }
516
+ }
517
+ },
518
+ "source": [
519
+ "We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
520
+ ]
521
+ },
522
+ {
523
+ "cell_type": "code",
524
+ "execution_count": 51,
525
+ "metadata": {
526
+ "collapsed": false,
527
+ "gather": {
528
+ "logged": 1706449443055
529
+ },
530
+ "jupyter": {
531
+ "outputs_hidden": false,
532
+ "source_hidden": false
533
+ },
534
+ "nteract": {
535
+ "transient": {
536
+ "deleting": false
537
+ }
538
+ }
539
+ },
540
+ "outputs": [],
541
+ "source": [
542
+ "ds = DatasetDict()\n",
543
+ "\n",
544
+ "for i in [\"test\", \"train\", \"val\"]:\n",
545
+ " tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
546
+ " ds[i] = Dataset(tab)\n",
547
+ "\n",
548
+ "dataset = ds"
549
+ ]
550
+ },
551
+ {
552
+ "cell_type": "markdown",
553
+ "metadata": {},
554
+ "source": [
555
+ "### Tokenisation and encoding"
556
+ ]
557
+ },
558
+ {
559
+ "cell_type": "code",
560
+ "execution_count": 52,
561
+ "metadata": {
562
+ "datalore": {
563
+ "hide_input_from_viewers": true,
564
+ "hide_output_from_viewers": true,
565
+ "node_id": "I7n646PIscsUZRoHu6m7zm",
566
+ "type": "CODE"
567
+ },
568
+ "gather": {
569
+ "logged": 1706449638377
570
+ }
571
+ },
572
+ "outputs": [],
573
+ "source": [
574
+ "tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
575
+ ]
576
+ },
577
+ {
578
+ "cell_type": "code",
579
+ "execution_count": 53,
580
+ "metadata": {
581
+ "datalore": {
582
+ "hide_input_from_viewers": true,
583
+ "hide_output_from_viewers": true,
584
+ "node_id": "QBLOSI0yVIslV7v7qX9ZC3",
585
+ "type": "CODE"
586
+ },
587
+ "gather": {
588
+ "logged": 1706449642580
589
+ }
590
+ },
591
+ "outputs": [],
592
+ "source": [
593
+ "def tokenize_and_encode(examples):\n",
594
+ " return tokenizer(examples[\"text\"], truncation=True)"
595
+ ]
596
+ },
597
+ {
598
+ "cell_type": "code",
599
+ "execution_count": 54,
600
+ "metadata": {
601
+ "datalore": {
602
+ "hide_input_from_viewers": true,
603
+ "hide_output_from_viewers": true,
604
+ "node_id": "slHeNysZOX9uWS9PB7jFDb",
605
+ "type": "CODE"
606
+ },
607
+ "gather": {
608
+ "logged": 1706449721161
609
+ }
610
+ },
611
+ "outputs": [
612
+ {
613
+ "name": "stderr",
614
+ "output_type": "stream",
615
+ "text": [
616
+ "Map: 100%|██████████| 27224/27224 [00:11<00:00, 2347.91 examples/s]\n",
617
+ "Map: 100%|██████████| 127044/127044 [00:52<00:00, 2417.41 examples/s]\n",
618
+ "Map: 100%|██████████| 27224/27224 [00:11<00:00, 2376.02 examples/s]\n"
619
+ ]
620
+ }
621
+ ],
622
+ "source": [
623
+ "cols = dataset[\"train\"].column_names\n",
624
+ "cols.remove(\"labels\")\n",
625
+ "ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
626
+ ]
627
+ },
628
+ {
629
+ "cell_type": "markdown",
630
+ "metadata": {},
631
+ "source": [
632
+ "### Training"
633
+ ]
634
+ },
635
+ {
636
+ "cell_type": "code",
637
+ "execution_count": 55,
638
+ "metadata": {
639
+ "datalore": {
640
+ "hide_input_from_viewers": true,
641
+ "hide_output_from_viewers": true,
642
+ "node_id": "itXWkbDw9sqbkMuDP84QoT",
643
+ "type": "CODE"
644
+ },
645
+ "gather": {
646
+ "logged": 1706449743072
647
+ }
648
+ },
649
+ "outputs": [],
650
+ "source": [
651
+ "class MultiLabelTrainer(Trainer):\n",
652
+ " def compute_loss(self, model, inputs, return_outputs=False):\n",
653
+ " labels = inputs.pop(\"labels\")\n",
654
+ " outputs = model(**inputs)\n",
655
+ " logits = outputs.logits\n",
656
+ " loss_fct = torch.nn.BCEWithLogitsLoss()\n",
657
+ " loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
658
+ " labels.float().view(-1, self.model.config.num_labels))\n",
659
+ " return (loss, outputs) if return_outputs else loss"
660
+ ]
661
+ },
662
+ {
663
+ "cell_type": "code",
664
+ "execution_count": 56,
665
+ "metadata": {
666
+ "datalore": {
667
+ "hide_input_from_viewers": true,
668
+ "hide_output_from_viewers": true,
669
+ "node_id": "ZQU7aW6TV45VmhHOQRzcnF",
670
+ "type": "CODE"
671
+ },
672
+ "gather": {
673
+ "logged": 1706449761205
674
+ }
675
+ },
676
+ "outputs": [
677
+ {
678
+ "name": "stderr",
679
+ "output_type": "stream",
680
+ "text": [
681
+ "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
682
+ "You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
683
+ ]
684
+ }
685
+ ],
686
+ "source": [
687
+ "model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
688
+ ]
689
+ },
690
+ {
691
+ "cell_type": "code",
692
+ "execution_count": 57,
693
+ "metadata": {
694
+ "datalore": {
695
+ "hide_input_from_viewers": true,
696
+ "hide_output_from_viewers": true,
697
+ "node_id": "swhgyyyxoGL8HjnXJtMuSW",
698
+ "type": "CODE"
699
+ },
700
+ "gather": {
701
+ "logged": 1706449761541
702
+ }
703
+ },
704
+ "outputs": [],
705
+ "source": [
706
+ "def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
707
+ " y_pred = torch.from_numpy(y_pred)\n",
708
+ " y_true = torch.from_numpy(y_true)\n",
709
+ "\n",
710
+ " if sigmoid:\n",
711
+ " y_pred = y_pred.sigmoid()\n",
712
+ "\n",
713
+ " return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
714
+ ]
715
+ },
716
+ {
717
+ "cell_type": "code",
718
+ "execution_count": 58,
719
+ "metadata": {
720
+ "datalore": {
721
+ "hide_input_from_viewers": true,
722
+ "hide_output_from_viewers": true,
723
+ "node_id": "1Uq3HtkaBxtHNAnSwit5cI",
724
+ "type": "CODE"
725
+ },
726
+ "gather": {
727
+ "logged": 1706449761720
728
+ }
729
+ },
730
+ "outputs": [],
731
+ "source": [
732
+ "def compute_metrics(eval_pred):\n",
733
+ " predictions, labels = eval_pred\n",
734
+ " return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
735
+ ]
736
+ },
737
+ {
738
+ "cell_type": "code",
739
+ "execution_count": 63,
740
+ "metadata": {
741
+ "datalore": {
742
+ "hide_input_from_viewers": true,
743
+ "hide_output_from_viewers": true,
744
+ "node_id": "1iPZOTKPwSkTgX5dORqT89",
745
+ "type": "CODE"
746
+ },
747
+ "gather": {
748
+ "logged": 1706449761893
749
+ }
750
+ },
751
+ "outputs": [],
752
+ "source": [
753
+ "args = TrainingArguments(\n",
754
+ " output_dir=\"vaers\",\n",
755
+ " evaluation_strategy=\"epoch\",\n",
756
+ " learning_rate=2e-5,\n",
757
+ " per_device_train_batch_size=BATCH_SIZE,\n",
758
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
759
+ " num_train_epochs=EPOCHS,\n",
760
+ " weight_decay=.01,\n",
761
+ " logging_steps=1,\n",
762
+ " run_name=f\"daedra-training\",\n",
763
+ " report_to=[\"wandb\"]\n",
764
+ ")"
765
+ ]
766
+ },
767
+ {
768
+ "cell_type": "code",
769
+ "execution_count": 64,
770
+ "metadata": {
771
+ "datalore": {
772
+ "hide_input_from_viewers": true,
773
+ "hide_output_from_viewers": true,
774
+ "node_id": "bnRkNvRYltLun6gCEgL7v0",
775
+ "type": "CODE"
776
+ },
777
+ "gather": {
778
+ "logged": 1706449769103
779
+ }
780
+ },
781
+ "outputs": [],
782
+ "source": [
783
+ "multi_label_trainer = MultiLabelTrainer(\n",
784
+ " model, \n",
785
+ " args, \n",
786
+ " train_dataset=ds_enc[\"train\"], \n",
787
+ " eval_dataset=ds_enc[\"test\"], \n",
788
+ " compute_metrics=compute_metrics, \n",
789
+ " tokenizer=tokenizer\n",
790
+ ")"
791
+ ]
792
+ },
793
+ {
794
+ "cell_type": "code",
795
+ "execution_count": 71,
796
+ "metadata": {
797
+ "datalore": {
798
+ "hide_input_from_viewers": true,
799
+ "hide_output_from_viewers": true,
800
+ "node_id": "LO54PlDkWQdFrzV25FvduB",
801
+ "type": "CODE"
802
+ },
803
+ "gather": {
804
+ "logged": 1706449880674
805
+ }
806
+ },
807
+ "outputs": [
808
+ {
809
+ "data": {
810
+ "text/html": [
811
+ "Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
812
+ ],
813
+ "text/plain": [
814
+ "<IPython.core.display.HTML object>"
815
+ ]
816
+ },
817
+ "metadata": {},
818
+ "output_type": "display_data"
819
+ },
820
+ {
821
+ "data": {
822
+ "text/html": [
823
+ "Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
824
+ ],
825
+ "text/plain": [
826
+ "<IPython.core.display.HTML object>"
827
+ ]
828
+ },
829
+ "metadata": {},
830
+ "output_type": "display_data"
831
+ },
832
+ {
833
+ "name": "stderr",
834
+ "output_type": "stream",
835
+ "text": [
836
+ "\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
837
+ ]
838
+ },
839
+ {
840
+ "data": {
841
+ "text/html": [
842
+ "Tracking run with wandb version 0.16.2"
843
+ ],
844
+ "text/plain": [
845
+ "<IPython.core.display.HTML object>"
846
+ ]
847
+ },
848
+ "metadata": {},
849
+ "output_type": "display_data"
850
+ },
851
+ {
852
+ "data": {
853
+ "text/html": [
854
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141352-spfdhiij</code>"
855
+ ],
856
+ "text/plain": [
857
+ "<IPython.core.display.HTML object>"
858
+ ]
859
+ },
860
+ "metadata": {},
861
+ "output_type": "display_data"
862
+ },
863
+ {
864
+ "data": {
865
+ "text/html": [
866
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
867
+ ],
868
+ "text/plain": [
869
+ "<IPython.core.display.HTML object>"
870
+ ]
871
+ },
872
+ "metadata": {},
873
+ "output_type": "display_data"
874
+ },
875
+ {
876
+ "data": {
877
+ "text/html": [
878
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
879
+ ],
880
+ "text/plain": [
881
+ "<IPython.core.display.HTML object>"
882
+ ]
883
+ },
884
+ "metadata": {},
885
+ "output_type": "display_data"
886
+ },
887
+ {
888
+ "data": {
889
+ "text/html": [
890
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij</a>"
891
+ ],
892
+ "text/plain": [
893
+ "<IPython.core.display.HTML object>"
894
+ ]
895
+ },
896
+ "metadata": {},
897
+ "output_type": "display_data"
898
+ },
899
+ {
900
+ "data": {
901
+ "text/html": [
902
+ "Finishing last run (ID:spfdhiij) before initializing another..."
903
+ ],
904
+ "text/plain": [
905
+ "<IPython.core.display.HTML object>"
906
+ ]
907
+ },
908
+ "metadata": {},
909
+ "output_type": "display_data"
910
+ },
911
+ {
912
+ "data": {
913
+ "text/html": [
914
+ " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
915
+ ],
916
+ "text/plain": [
917
+ "<IPython.core.display.HTML object>"
918
+ ]
919
+ },
920
+ "metadata": {},
921
+ "output_type": "display_data"
922
+ },
923
+ {
924
+ "data": {
925
+ "text/html": [
926
+ "Find logs at: <code>./wandb/run-20240128_141352-spfdhiij/logs</code>"
927
+ ],
928
+ "text/plain": [
929
+ "<IPython.core.display.HTML object>"
930
+ ]
931
+ },
932
+ "metadata": {},
933
+ "output_type": "display_data"
934
+ },
935
+ {
936
+ "data": {
937
+ "text/html": [
938
+ "Successfully finished last run (ID:spfdhiij). Initializing new run:<br/>"
939
+ ],
940
+ "text/plain": [
941
+ "<IPython.core.display.HTML object>"
942
+ ]
943
+ },
944
+ "metadata": {},
945
+ "output_type": "display_data"
946
+ },
947
+ {
948
+ "data": {
949
+ "text/html": [
950
+ "Tracking run with wandb version 0.16.2"
951
+ ],
952
+ "text/plain": [
953
+ "<IPython.core.display.HTML object>"
954
+ ]
955
+ },
956
+ "metadata": {},
957
+ "output_type": "display_data"
958
+ },
959
+ {
960
+ "data": {
961
+ "text/html": [
962
+ "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141354-mpe6cpuz</code>"
963
+ ],
964
+ "text/plain": [
965
+ "<IPython.core.display.HTML object>"
966
+ ]
967
+ },
968
+ "metadata": {},
969
+ "output_type": "display_data"
970
+ },
971
+ {
972
+ "data": {
973
+ "text/html": [
974
+ "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
975
+ ],
976
+ "text/plain": [
977
+ "<IPython.core.display.HTML object>"
978
+ ]
979
+ },
980
+ "metadata": {},
981
+ "output_type": "display_data"
982
+ },
983
+ {
984
+ "data": {
985
+ "text/html": [
986
+ " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
987
+ ],
988
+ "text/plain": [
989
+ "<IPython.core.display.HTML object>"
990
+ ]
991
+ },
992
+ "metadata": {},
993
+ "output_type": "display_data"
994
+ },
995
+ {
996
+ "data": {
997
+ "text/html": [
998
+ " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz</a>"
999
+ ],
1000
+ "text/plain": [
1001
+ "<IPython.core.display.HTML object>"
1002
+ ]
1003
+ },
1004
+ "metadata": {},
1005
+ "output_type": "display_data"
1006
+ },
1007
+ {
1008
+ "data": {
1009
+ "text/html": [
1010
+ "<style>\n",
1011
+ " table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
1012
+ " .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
1013
+ " .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
1014
+ " </style>\n",
1015
+ "<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>▁</td></tr><tr><td>eval/loss</td><td>▁</td></tr><tr><td>eval/runtime</td><td>▁</td></tr><tr><td>eval/samples_per_second</td><td>▁</td></tr><tr><td>eval/steps_per_second</td><td>▁</td></tr><tr><td>train/global_step</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>0.42136</td></tr><tr><td>eval/loss</td><td>0.69069</td></tr><tr><td>eval/runtime</td><td>79.1475</td></tr><tr><td>eval/samples_per_second</td><td>343.965</td></tr><tr><td>eval/steps_per_second</td><td>2.691</td></tr><tr><td>train/global_step</td><td>0</td></tr></table><br/></div></div>"
1016
+ ],
1017
+ "text/plain": [
1018
+ "<IPython.core.display.HTML object>"
1019
+ ]
1020
+ },
1021
+ "metadata": {},
1022
+ "output_type": "display_data"
1023
+ },
1024
+ {
1025
+ "data": {
1026
+ "text/html": [
1027
+ " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
1028
+ ],
1029
+ "text/plain": [
1030
+ "<IPython.core.display.HTML object>"
1031
+ ]
1032
+ },
1033
+ "metadata": {},
1034
+ "output_type": "display_data"
1035
+ },
1036
+ {
1037
+ "data": {
1038
+ "text/html": [
1039
+ "Find logs at: <code>./wandb/run-20240128_141354-mpe6cpuz/logs</code>"
1040
+ ],
1041
+ "text/plain": [
1042
+ "<IPython.core.display.HTML object>"
1043
+ ]
1044
+ },
1045
+ "metadata": {},
1046
+ "output_type": "display_data"
1047
+ }
1048
+ ],
1049
+ "source": [
1050
+ "if SUBSAMPLING != 1.0:\n",
1051
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
1052
+ "else:\n",
1053
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
1054
+ " \n",
1055
+ "wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
1056
+ "\n",
1057
+ "multi_label_trainer.evaluate()\n",
1058
+ "wandb.finish()"
1059
+ ]
1060
+ },
1061
+ {
1062
+ "cell_type": "code",
1063
+ "execution_count": 62,
1064
+ "metadata": {
1065
+ "datalore": {
1066
+ "hide_input_from_viewers": true,
1067
+ "hide_output_from_viewers": true,
1068
+ "node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
1069
+ "type": "CODE"
1070
+ },
1071
+ "gather": {
1072
+ "logged": 1706449934637
1073
+ }
1074
+ },
1075
+ "outputs": [
1076
+ {
1077
+ "ename": "RuntimeError",
1078
+ "evalue": "Caught RuntimeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py\", line 61, in _worker\n output = module(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 1002, in forward\n distilbert_output = self.distilbert(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 822, in forward\n return self.transformer(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 587, in forward\n layer_outputs = layer_module(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 513, in forward\n sa_output = self.attention(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 243, in forward\n scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)\nRuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 15.77 GiB total capacity; 14.69 GiB already allocated; 5.12 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n",
1079
+ "output_type": "error",
1080
+ "traceback": [
1081
+ "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
1082
+ "\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
1083
+ "Cell \u001b[0;32mIn[62], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmulti_label_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
1084
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1537\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
1085
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:1869\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1866\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1868\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1869\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1872\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1873\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1874\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1875\u001b[0m ):\n\u001b[1;32m 1876\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1877\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
1086
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:2768\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2765\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2767\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2768\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2770\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 2771\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
1087
+ "Cell \u001b[0;32mIn[55], line 4\u001b[0m, in \u001b[0;36mMultiLabelTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_loss\u001b[39m(\u001b[38;5;28mself\u001b[39m, model, inputs, return_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 3\u001b[0m labels \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits\n\u001b[1;32m 6\u001b[0m loss_fct \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mBCEWithLogitsLoss()\n",
1088
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
1089
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:168\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 167\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 168\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
1090
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:178\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas, inputs, kwargs):\n\u001b[0;32m--> 178\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
1091
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py:86\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 84\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m---> 86\u001b[0m \u001b[43moutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 87\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
1092
+ "File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/_utils.py:461\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 457\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 460\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 461\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
1093
+ "\u001b[0;31mRuntimeError\u001b[0m: Caught RuntimeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py\", line 61, in _worker\n output = module(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 1002, in forward\n distilbert_output = self.distilbert(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 822, in forward\n return self.transformer(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 587, in forward\n layer_outputs = layer_module(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 513, in forward\n sa_output = self.attention(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 243, in forward\n scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)\nRuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 15.77 GiB total capacity; 14.69 GiB already allocated; 5.12 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n"
1094
+ ]
1095
+ }
1096
+ ],
1097
+ "source": [
1098
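+ "# Tag the W&B run with the sampling regime so runs remain comparable\n",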
+ "if SUBSAMPLING != 1.0:\n",
1099
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
1100
+ "else:\n",
1101
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
1102
+ " \n",
1103
+ "wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
1104
+ "\n",
1105
+ "multi_label_trainer.train()\n",
1106
+ "wandb.finish()"
1107
+ ]
1108
+ },
1109
+ {
1110
+ "cell_type": "markdown",
1111
+ "metadata": {},
1112
+ "source": [
1113
+ "### Evaluation"
1114
+ ]
1115
+ },
1116
+ {
1117
+ "cell_type": "markdown",
1118
+ "metadata": {},
1119
+ "source": [
1120
+ "We instantiate a classifier `pipeline` and push it to CUDA."
1121
+ ]
1122
+ },
1123
+ {
1124
+ "cell_type": "code",
1125
+ "execution_count": null,
1126
+ "metadata": {
1127
+ "datalore": {
1128
+ "hide_input_from_viewers": true,
1129
+ "hide_output_from_viewers": true,
1130
+ "node_id": "kHoUdBeqcyVXDSGv54C4aE",
1131
+ "type": "CODE"
1132
+ },
1133
+ "gather": {
1134
+ "logged": 1706411459928
1135
+ }
1136
+ },
1137
+ "outputs": [],
1138
+ "source": [
1139
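+ "# Wrap the fine-tuned model and its tokenizer in an inference pipeline on the first GPU\n",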
+ "classifier = pipeline(\"text-classification\", \n",
1140
+ " model, \n",
1141
+ " tokenizer=tokenizer, \n",
1142
+ " device=\"cuda:0\")"
1143
+ ]
1144
+ },
1145
+ {
1146
+ "cell_type": "markdown",
1147
+ "metadata": {},
1148
+ "source": [
1149
+ "We use the same tokenizer used for training to tokenize/encode the validation set."
1150
+ ]
1151
+ },
1152
+ {
1153
+ "cell_type": "code",
1154
+ "execution_count": null,
1155
+ "metadata": {
1156
+ "datalore": {
1157
+ "hide_input_from_viewers": true,
1158
+ "hide_output_from_viewers": true,
1159
+ "node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
1160
+ "type": "CODE"
1161
+ },
1162
+ "gather": {
1163
+ "logged": 1706411523285
1164
+ }
1165
+ },
1166
+ "outputs": [],
1167
+ "source": [
1168
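+ "# Pad every sequence to the model's maximum length and truncate anything longer\n",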
+ "test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
1169
+ " max_length=None, \n",
1170
+ " padding='max_length', \n",
1171
+ " return_token_type_ids=True, \n",
1172
+ " truncation=True)"
1173
+ ]
1174
+ },
1175
+ {
1176
+ "cell_type": "markdown",
1177
+ "metadata": {},
1178
+ "source": [
1179
+ "Once we've made the data loadable by putting it into a `DataLoader`, we "
1180
+ ]
1181
+ },
1182
+ {
1183
+ "cell_type": "code",
1184
+ "execution_count": null,
1185
+ "metadata": {
1186
+ "datalore": {
1187
+ "hide_input_from_viewers": true,
1188
+ "hide_output_from_viewers": true,
1189
+ "node_id": "MWfGq2tTkJNzFiDoUPq2X7",
1190
+ "type": "CODE"
1191
+ },
1192
+ "gather": {
1193
+ "logged": 1706411543379
1194
+ }
1195
+ },
1196
+ "outputs": [],
1197
+ "source": [
1198
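+ "# A sequential sampler keeps predictions aligned with the validation set's row order\n",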
+ "test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
1199
+ " torch.tensor(test_encodings['attention_mask']), \n",
1200
+ " torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
1201
+ " torch.tensor(test_encodings['token_type_ids']))\n",
1202
+ "test_dataloader = torch.utils.data.DataLoader(test_data, \n",
1203
+ " sampler=torch.utils.data.SequentialSampler(test_data), \n",
1204
+ " batch_size=BATCH_SIZE)"
1205
+ ]
1206
+ },
1207
+ {
1208
+ "cell_type": "code",
1209
+ "execution_count": null,
1210
+ "metadata": {
1211
+ "datalore": {
1212
+ "hide_input_from_viewers": true,
1213
+ "hide_output_from_viewers": true,
1214
+ "node_id": "1SJCSrQTRCexFCNCIyRrzL",
1215
+ "type": "CODE"
1216
+ },
1217
+ "gather": {
1218
+ "logged": 1706411587843
1219
+ }
1220
+ },
1221
+ "outputs": [],
1222
+ "source": [
1223
+ "model.eval()\n",
1224
+ "\n",
1225
+ "logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
1226
+ "\n",
1227
+ "for i, batch in enumerate(test_dataloader):\n",
1228
+ " batch = tuple(t.to(device) for t in batch)\n",
1229
+ " \n",
1230
+ " # Unpack the inputs from our dataloader\n",
1231
+ " b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
1232
+ " \n",
1233
+ " with torch.no_grad():\n",
1234
+ " outs = model(b_input_ids, attention_mask=b_input_mask)\n",
1235
+ " b_logit_pred = outs[0]\n",
1236
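+ " # Sigmoid yields an independent probability per class (multi-label setting)\n",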
+ " pred_label = torch.sigmoid(b_logit_pred)\n",
1237
+ "\n",
1238
+ " b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
1239
+ " pred_label = pred_label.to('cpu').numpy()\n",
1240
+ " b_labels = b_labels.to('cpu').numpy()\n",
1241
+ "\n",
1242
+ " tokenized_texts.append(b_input_ids)\n",
1243
+ " logit_preds.append(b_logit_pred)\n",
1244
+ " true_labels.append(b_labels)\n",
1245
+ " pred_labels.append(pred_label)\n",
1246
+ "\n",
1247
+ "# Flatten outputs\n",
1248
+ "tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
1249
+ "pred_labels = [item for sublist in pred_labels for item in sublist]\n",
1250
+ "true_labels = [item for sublist in true_labels for item in sublist]\n",
1251
+ "\n",
1252
+ "# Converting flattened binary values to boolean values\n",
1253
+ "true_bools = [tl == 1 for tl in true_labels]\n",
1254
+ "pred_bools = [pl > 0.50 for pl in pred_labels] "
1255
+ ]
1256
+ },
1257
+ {
1258
+ "cell_type": "markdown",
1259
+ "metadata": {},
1260
+ "source": [
1261
+ "We create a classification report:"
1262
+ ]
1263
+ },
1264
+ {
1265
+ "cell_type": "code",
1266
+ "execution_count": null,
1267
+ "metadata": {
1268
+ "datalore": {
1269
+ "hide_input_from_viewers": true,
1270
+ "hide_output_from_viewers": true,
1271
+ "node_id": "eBprrgF086mznPbPVBpOLS",
1272
+ "type": "CODE"
1273
+ },
1274
+ "gather": {
1275
+ "logged": 1706411588249
1276
+ }
1277
+ },
1278
+ "outputs": [],
1279
+ "source": [
1280
+ "print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
1281
+ "print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
1282
+ "clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
1283
+ "print(clf_report)"
1284
+ ]
1285
+ },
1286
+ {
1287
+ "cell_type": "markdown",
1288
+ "metadata": {},
1289
+ "source": [
1290
+ "Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
1291
+ ]
1292
+ },
1293
+ {
1294
+ "cell_type": "code",
1295
+ "execution_count": null,
1296
+ "metadata": {
1297
+ "datalore": {
1298
+ "hide_input_from_viewers": true,
1299
+ "hide_output_from_viewers": true,
1300
+ "node_id": "yELHY0IEwMlMw3x6e7hoD1",
1301
+ "type": "CODE"
1302
+ },
1303
+ "gather": {
1304
+ "logged": 1706411588638
1305
+ }
1306
+ },
1307
+ "outputs": [],
1308
+ "source": [
1309
+ "# Creating a map of class names from class numbers\n",
1310
+ "idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
1311
+ ]
1312
+ },
1313
+ {
1314
+ "cell_type": "code",
1315
+ "execution_count": null,
1316
+ "metadata": {
1317
+ "datalore": {
1318
+ "hide_input_from_viewers": true,
1319
+ "hide_output_from_viewers": true,
1320
+ "node_id": "jH0S35dDteUch01sa6me6e",
1321
+ "type": "CODE"
1322
+ },
1323
+ "gather": {
1324
+ "logged": 1706411589004
1325
+ }
1326
+ },
1327
+ "outputs": [],
1328
+ "source": [
1329
+ "true_label_idxs, pred_label_idxs = [], []\n",
1330
+ "\n",
1331
+ "for vals in true_bools:\n",
1332
+ " true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
1333
+ "for vals in pred_bools:\n",
1334
+ " pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
1335
+ ]
1336
+ },
1337
+ {
1338
+ "cell_type": "code",
1339
+ "execution_count": null,
1340
+ "metadata": {
1341
+ "datalore": {
1342
+ "hide_input_from_viewers": true,
1343
+ "hide_output_from_viewers": true,
1344
+ "node_id": "h4vHL8XdGpayZ6xLGJUF6F",
1345
+ "type": "CODE"
1346
+ },
1347
+ "gather": {
1348
+ "logged": 1706411589301
1349
+ }
1350
+ },
1351
+ "outputs": [],
1352
+ "source": [
1353
+ "true_label_texts, pred_label_texts = [], []\n",
1354
+ "\n",
1355
+ "for vals in true_label_idxs:\n",
1356
+ " if vals:\n",
1357
+ " true_label_texts.append([idx2label[val] for val in vals])\n",
1358
+ " else:\n",
1359
+ " true_label_texts.append(vals)\n",
1360
+ "\n",
1361
+ "for vals in pred_label_idxs:\n",
1362
+ " if vals:\n",
1363
+ " pred_label_texts.append([idx2label[val] for val in vals])\n",
1364
+ " else:\n",
1365
+ " pred_label_texts.append(vals)"
1366
+ ]
1367
+ },
1368
+ {
1369
+ "cell_type": "code",
1370
+ "execution_count": null,
1371
+ "metadata": {
1372
+ "datalore": {
1373
+ "hide_input_from_viewers": true,
1374
+ "hide_output_from_viewers": true,
1375
+ "node_id": "SxUmVHfQISEeptg1SawOmB",
1376
+ "type": "CODE"
1377
+ },
1378
+ "gather": {
1379
+ "logged": 1706411591952
1380
+ }
1381
+ },
1382
+ "outputs": [],
1383
+ "source": [
1384
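+ "# Decode token IDs back into readable report text for the comparison table\n",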
+ "symptom_texts = [tokenizer.decode(text,\n",
1385
+ " skip_special_tokens=True,\n",
1386
+ " clean_up_tokenization_spaces=False) for text in tokenized_texts]"
1387
+ ]
1388
+ },
1389
+ {
1390
+ "cell_type": "code",
1391
+ "execution_count": null,
1392
+ "metadata": {
1393
+ "datalore": {
1394
+ "hide_input_from_viewers": true,
1395
+ "hide_output_from_viewers": true,
1396
+ "node_id": "BxFNigNGRLTOqraI55BPSH",
1397
+ "type": "CODE"
1398
+ },
1399
+ "gather": {
1400
+ "logged": 1706411592512
1401
+ }
1402
+ },
1403
+ "outputs": [],
1404
+ "source": [
1405
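+ "# Assemble and persist a per-report table of true vs. predicted labels\n",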
+ "comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
1406
+ " 'true_labels': true_label_texts, \n",
1407
+ " 'pred_labels':pred_label_texts})\n",
1408
+ "comparisons_df.to_csv('comparisons.csv')\n",
1409
+ "comparisons_df"
1410
+ ]
1411
+ },
1412
+ {
1413
+ "cell_type": "markdown",
1414
+ "metadata": {},
1415
+ "source": [
1416
+ "### Shapley analysis"
1417
+ ]
1418
+ },
1419
+ {
1420
+ "cell_type": "code",
1421
+ "execution_count": null,
1422
+ "metadata": {
1423
+ "datalore": {
1424
+ "hide_input_from_viewers": true,
1425
+ "hide_output_from_viewers": true,
1426
+ "node_id": "OpdZcoenX2HwzLdai7K5UA",
1427
+ "type": "CODE"
1428
+ },
1429
+ "gather": {
1430
+ "logged": 1706415109071
1431
+ }
1432
+ },
1433
+ "outputs": [],
1434
+ "source": [
1435
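+ "# shap.Explainer can wrap a Hugging Face text-classification pipeline directly\n",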
+ "explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
1436
+ ]
1437
+ },
1438
+ {
1439
+ "cell_type": "markdown",
1440
+ "metadata": {
1441
+ "nteract": {
1442
+ "transient": {
1443
+ "deleting": false
1444
+ }
1445
+ }
1446
+ },
1447
+ "source": [
1448
+ "#### Sampling correct predictions\n",
1449
+ "\n",
1450
+ "First, let's look at some correct predictions of deaths:"
1451
+ ]
1452
+ },
1453
+ {
1454
+ "cell_type": "code",
1455
+ "execution_count": null,
1456
+ "metadata": {
1457
+ "collapsed": false,
1458
+ "gather": {
1459
+ "logged": 1706414973990
1460
+ },
1461
+ "jupyter": {
1462
+ "outputs_hidden": false
1463
+ },
1464
+ "nteract": {
1465
+ "transient": {
1466
+ "deleting": false
1467
+ }
1468
+ }
1469
+ },
1470
+ "outputs": [],
1471
+ "source": [
1472
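+ "# Select reports whose true label set is exactly ['DIED']\n",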
+ "correct_death_predictions = comparisons_df[comparisons_df['true_labels'].astype(str) == \"['DIED']\"]"
1473
+ ]
1474
+ },
1475
+ {
1476
+ "cell_type": "code",
1477
+ "execution_count": null,
1478
+ "metadata": {
1479
+ "collapsed": false,
1480
+ "gather": {
1481
+ "logged": 1706415114683
1482
+ },
1483
+ "jupyter": {
1484
+ "outputs_hidden": false
1485
+ },
1486
+ "nteract": {
1487
+ "transient": {
1488
+ "deleting": false
1489
+ }
1490
+ }
1491
+ },
1492
+ "outputs": [],
1493
+ "source": [
1494
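+ "# Truncate each sampled report to 512 characters and wrap it in a Dataset for SHAP\n",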
+ "texts = [i[:512] for i in correct_death_predictions.sample(n=6).symptom_text]\n",
1495
+ "idxs = [i for i in range(len(texts))]\n",
1496
+ "\n",
1497
+ "d_s = Dataset(Table.from_arrays([idxs, texts], names=[\"idx\", \"texts\"]))"
1498
+ ]
1499
+ },
1500
+ {
1501
+ "cell_type": "code",
1502
+ "execution_count": null,
1503
+ "metadata": {
1504
+ "collapsed": false,
1505
+ "gather": {
1506
+ "logged": 1706415129229
1507
+ },
1508
+ "jupyter": {
1509
+ "outputs_hidden": false
1510
+ },
1511
+ "nteract": {
1512
+ "transient": {
1513
+ "deleting": false
1514
+ }
1515
+ }
1516
+ },
1517
+ "outputs": [],
1518
+ "source": [
1519
+ "shap_values = explainer(d_s[\"texts\"])"
1520
+ ]
1521
+ },
1522
+ {
1523
+ "cell_type": "code",
1524
+ "execution_count": null,
1525
+ "metadata": {
1526
+ "collapsed": false,
1527
+ "gather": {
1528
+ "logged": 1706415151494
1529
+ },
1530
+ "jupyter": {
1531
+ "outputs_hidden": false
1532
+ },
1533
+ "nteract": {
1534
+ "transient": {
1535
+ "deleting": false
1536
+ }
1537
+ }
1538
+ },
1539
+ "outputs": [],
1540
+ "source": [
1541
+ "shap.plots.text(shap_values)"
1542
+ ]
1543
+ },
1544
+ {
1545
+ "cell_type": "code",
1546
+ "execution_count": null,
1547
+ "metadata": {
1548
+ "collapsed": false,
1549
+ "jupyter": {
1550
+ "outputs_hidden": false
1551
+ },
1552
+ "nteract": {
1553
+ "transient": {
1554
+ "deleting": false
1555
+ }
1556
+ }
1557
+ },
1558
+ "outputs": [],
1559
+ "source": []
1560
+ }
1561
+ ],
1562
+ "metadata": {
1563
+ "datalore": {
1564
+ "base_environment": "default",
1565
+ "computation_mode": "JUPYTER",
1566
+ "package_manager": "pip",
1567
+ "packages": [
1568
+ {
1569
+ "name": "datasets",
1570
+ "source": "PIP",
1571
+ "version": "2.16.1"
1572
+ },
1573
+ {
1574
+ "name": "torch",
1575
+ "source": "PIP",
1576
+ "version": "2.1.2"
1577
+ },
1578
+ {
1579
+ "name": "accelerate",
1580
+ "source": "PIP",
1581
+ "version": "0.26.1"
1582
+ }
1583
+ ],
1584
+ "report_row_ids": [
1585
+ "un8W7ez7ZwoGb5Co6nydEV",
1586
+ "40nN9Hvgi1clHNV5RAemI5",
1587
+ "TgRD90H5NSPpKS41OeXI1w",
1588
+ "ZOm5BfUs3h1EGLaUkBGeEB",
1589
+ "kOP0CZWNSk6vqE3wkPp7Vc",
1590
+ "W4PWcOu2O2pRaZyoE2W80h",
1591
+ "RolbOnQLIftk0vy9mIcz5M",
1592
+ "8OPhUgbaNJmOdiq5D3a6vK",
1593
+ "5Qrt3jSvSrpK6Ne1hS6shL",
1594
+ "hTq7nFUrovN5Ao4u6dIYWZ",
1595
+ "I8WNZLpJ1DVP2wiCW7YBIB",
1596
+ "SawhU3I9BewSE1XBPstpNJ",
1597
+ "80EtLEl2FIE4FqbWnUD3nT"
1598
+ ],
1599
+ "version": 3
1600
+ },
1601
+ "kernelspec": {
1602
+ "display_name": "Python 3.8 - Pytorch and Tensorflow",
1603
+ "language": "python",
1604
+ "name": "python38-azureml-pt-tf"
1605
+ },
1606
+ "language_info": {
1607
+ "codemirror_mode": {
1608
+ "name": "ipython",
1609
+ "version": 3
1610
+ },
1611
+ "file_extension": ".py",
1612
+ "mimetype": "text/x-python",
1613
+ "name": "python",
1614
+ "nbconvert_exporter": "python",
1615
+ "pygments_lexer": "ipython3",
1616
+ "version": "3.8.5"
1617
+ },
1618
+ "microsoft": {
1619
+ "host": {
1620
+ "AzureML": {
1621
+ "notebookHasBeenCompleted": true
1622
+ }
1623
+ },
1624
+ "ms_spell_check": {
1625
+ "ms_spell_check_language": "en"
1626
+ }
1627
+ },
1628
+ "nteract": {
1629
+ "version": "nteract-front-end@1.0.0"
1630
+ }
1631
+ },
1632
+ "nbformat": 4,
1633
+ "nbformat_minor": 4
1634
+ }
notebooks/DAEDRA.ipynb ADDED
@@ -0,0 +1,671 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {}
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "source": [
15
+ "%pip install accelerate -U"
16
+ ],
17
+ "outputs": [
18
+ {
19
+ "output_type": "stream",
20
+ "name": "stdout",
21
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
22
+ }
23
+ ],
24
+ "execution_count": 1,
25
+ "metadata": {
26
+ "gather": {
27
+ "logged": 1706475754655
28
+ },
29
+ "nteract": {
30
+ "transient": {
31
+ "deleting": false
32
+ }
33
+ },
34
+ "tags": []
35
+ }
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "source": [
40
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
41
+ ],
42
+ "outputs": [
43
+ {
44
+ "output_type": "stream",
45
+ "name": "stdout",
46
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
47
+ }
48
+ ],
49
+ "execution_count": 2,
50
+ "metadata": {
51
+ "nteract": {
52
+ "transient": {
53
+ "deleting": false
54
+ }
55
+ }
56
+ }
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "import pandas as pd\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "import os\n",
65
+ "from typing import List, Union\n",
66
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
67
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
68
+ "import shap\n",
69
+ "import wandb\n",
70
+ "import evaluate\n",
71
+ "import logging\n",
72
+ "\n",
73
+ "wandb.finish()\n",
74
+ "\n",
75
+ "\n",
76
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
77
+ "\n",
78
+ "%load_ext watermark"
79
+ ],
80
+ "outputs": [
81
+ {
82
+ "output_type": "stream",
83
+ "name": "stderr",
84
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 17:46:15.020290: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 17:46:16.031641: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031779: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
85
+ }
86
+ ],
87
+ "execution_count": 3,
88
+ "metadata": {
89
+ "datalore": {
90
+ "hide_input_from_viewers": false,
91
+ "hide_output_from_viewers": false,
92
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
93
+ "report_properties": {
94
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
95
+ },
96
+ "type": "CODE"
97
+ },
98
+ "gather": {
99
+ "logged": 1706550378660
100
+ },
101
+ "tags": []
102
+ }
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "source": [
107
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
108
+ "\n",
109
+ "SEED: int = 42\n",
110
+ "\n",
111
+ "BATCH_SIZE: int = 32\n",
112
+ "EPOCHS: int = 5\n",
113
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
114
+ "\n",
115
+ "# WandB configuration\n",
116
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
117
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
118
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
119
+ ],
120
+ "outputs": [],
121
+ "execution_count": 4,
122
+ "metadata": {
123
+ "collapsed": false,
124
+ "gather": {
125
+ "logged": 1706550378812
126
+ },
127
+ "jupyter": {
128
+ "outputs_hidden": false
129
+ }
130
+ }
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "source": [
135
+ "%watermark --iversion"
136
+ ],
137
+ "outputs": [
138
+ {
139
+ "output_type": "stream",
140
+ "name": "stdout",
141
+ "text": "shap : 0.44.1\npandas : 2.0.2\nwandb : 0.16.2\nre : 2.2.1\nevaluate: 0.4.1\ntorch : 1.12.0\nnumpy : 1.23.5\nlogging : 0.5.1.2\n\n"
142
+ }
143
+ ],
144
+ "execution_count": 5,
145
+ "metadata": {
146
+ "collapsed": false,
147
+ "jupyter": {
148
+ "outputs_hidden": false
149
+ }
150
+ }
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "source": [
155
+ "!nvidia-smi"
156
+ ],
157
+ "outputs": [
158
+ {
159
+ "output_type": "stream",
160
+ "name": "stdout",
161
+ "text": "Mon Jan 29 17:46:18 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
162
+ }
163
+ ],
164
+ "execution_count": 6,
165
+ "metadata": {
166
+ "datalore": {
167
+ "hide_input_from_viewers": true,
168
+ "hide_output_from_viewers": true,
169
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
170
+ "type": "CODE"
171
+ }
172
+ }
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "source": [
177
+ "## Loading the data set"
178
+ ],
179
+ "metadata": {
180
+ "datalore": {
181
+ "hide_input_from_viewers": false,
182
+ "hide_output_from_viewers": false,
183
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
184
+ "report_properties": {
185
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
186
+ },
187
+ "type": "MD"
188
+ }
189
+ }
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "source": [
194
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
195
+ ],
196
+ "outputs": [],
197
+ "execution_count": 7,
198
+ "metadata": {
199
+ "collapsed": false,
200
+ "gather": {
201
+ "logged": 1706550381141
202
+ },
203
+ "jupyter": {
204
+ "outputs_hidden": false
205
+ }
206
+ }
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "source": [
211
+ "dataset"
212
+ ],
213
+ "outputs": [
214
+ {
215
+ "output_type": "execute_result",
216
+ "execution_count": 8,
217
+ "data": {
218
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
219
+ },
220
+ "metadata": {}
221
+ }
222
+ ],
223
+ "execution_count": 8,
224
+ "metadata": {
225
+ "collapsed": false,
226
+ "gather": {
227
+ "logged": 1706550381303
228
+ },
229
+ "jupyter": {
230
+ "outputs_hidden": false,
231
+ "source_hidden": false
232
+ },
233
+ "nteract": {
234
+ "transient": {
235
+ "deleting": false
236
+ }
237
+ }
238
+ }
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "source": [
243
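+ "# Shuffle, then keep a SUBSAMPLING fraction of each split for faster experimentation\n",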
+ "SUBSAMPLING = 0.01\n",
244
+ "\n",
245
+ "if SUBSAMPLING < 1:\n",
246
+ " _ = DatasetDict()\n",
247
+ " for each in dataset.keys():\n",
248
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
249
+ "\n",
250
+ " dataset = _"
251
+ ],
252
+ "outputs": [],
253
+ "execution_count": 9,
254
+ "metadata": {
255
+ "gather": {
256
+ "logged": 1706550381472
257
+ }
258
+ }
259
+ },
260
+ {
261
+ "cell_type": "markdown",
262
+ "source": [
263
+ "## Tokenisation and encoding"
264
+ ],
265
+ "metadata": {}
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "source": [
270
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
271
+ " return ds_enc"
272
+ ],
273
+ "outputs": [],
274
+ "execution_count": 10,
275
+ "metadata": {
276
+ "gather": {
277
+ "logged": 1706550381637
278
+ }
279
+ }
280
+ },
281
+ {
282
+ "cell_type": "markdown",
283
+ "source": [
284
+ "## Evaluation metrics"
285
+ ],
286
+ "metadata": {}
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "source": [
291
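+ "# Load metric implementations from the Hugging Face evaluate library\n",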
+ "accuracy = evaluate.load(\"accuracy\")\n",
292
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
293
+ "f1 = evaluate.load(\"f1\")"
294
+ ],
295
+ "outputs": [],
296
+ "execution_count": 11,
297
+ "metadata": {
298
+ "gather": {
299
+ "logged": 1706550381778
300
+ }
301
+ }
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "source": [
306
+ "def compute_metrics(eval_pred):\n",
307
+ " predictions, labels = eval_pred\n",
308
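+ " # Convert logits to predicted class indices\n",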
+ " predictions = np.argmax(predictions, axis=1)\n",
309
+ " return {\n",
310
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
311
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
312
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
313
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
314
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
315
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
316
+ " }"
317
+ ],
318
+ "outputs": [],
319
+ "execution_count": 12,
320
+ "metadata": {
321
+ "gather": {
322
+ "logged": 1706550381891
323
+ }
324
+ }
325
+ },
326
+ {
327
+ "cell_type": "markdown",
328
+ "source": [
329
+ "## Training"
330
+ ],
331
+ "metadata": {}
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "source": [
336
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
337
+ ],
338
+ "metadata": {}
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "source": [
343
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
344
+ ],
345
+ "outputs": [],
346
+ "execution_count": 13,
347
+ "metadata": {
348
+ "gather": {
349
+ "logged": 1706550382032
350
+ }
351
+ }
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "source": [
356
+ "def train_from_model(model_ckpt: str, push: bool = False):\n",
357
+ " print(f\"Initialising training based on {model_ckpt}...\")\n",
358
+ "\n",
359
+ " print(\"Tokenising...\")\n",
360
+ " tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
361
+ "\n",
362
+ " cols = dataset[\"train\"].column_names\n",
363
+ " cols.remove(\"label\")\n",
364
+ " ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
365
+ "\n",
366
+ " print(\"Loading model...\")\n",
367
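+ " # Fall back to TensorFlow weights (from_tf=True) when the checkpoint lacks PyTorch weights\n",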
+ " try:\n",
368
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
369
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
370
+ " id2label=label_map, \n",
371
+ " label2id={v:k for k,v in label_map.items()})\n",
372
+ " except OSError:\n",
373
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
374
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
375
+ " id2label=label_map, \n",
376
+ " label2id={v:k for k,v in label_map.items()},\n",
377
+ " from_tf=True)\n",
378
+ "\n",
379
+ "\n",
380
+ " args = TrainingArguments(\n",
381
+ " output_dir=\"vaers\",\n",
382
+ " evaluation_strategy=\"epoch\",\n",
383
+ " save_strategy=\"epoch\",\n",
384
+ " learning_rate=2e-5,\n",
385
+ " per_device_train_batch_size=BATCH_SIZE,\n",
386
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
387
+ " num_train_epochs=EPOCHS,\n",
388
+ " weight_decay=.01,\n",
389
+ " logging_steps=1,\n",
390
+ " load_best_model_at_end=True,\n",
391
+ " run_name=f\"daedra-training\",\n",
392
+ " report_to=[\"wandb\"])\n",
393
+ "\n",
394
+ " trainer = Trainer(\n",
395
+ " model=model,\n",
396
+ " args=args,\n",
397
+ " train_dataset=ds_enc[\"train\"],\n",
398
+ " eval_dataset=ds_enc[\"test\"],\n",
399
+ " tokenizer=tokenizer,\n",
400
+ " compute_metrics=compute_metrics)\n",
401
+ " \n",
402
+ " if SUBSAMPLING != 1.0:\n",
403
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
404
+ " else:\n",
405
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
406
+ "\n",
407
+ " wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
408
+ " wandb_tag.append(f\"base:{model_ckpt}\")\n",
409
+ " \n",
410
+ " wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
411
+ "\n",
412
+ " print(\"Starting training...\")\n",
413
+ "\n",
414
+ " trainer.train()\n",
415
+ "\n",
416
+ " print(\"Training finished.\")\n",
417
+ "\n",
418
+ " if push:\n",
419
+ " variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
420
+ " tokenizer._tokenizer.save(\"tokenizer.json\")\n",
421
+ " tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
422
+ " sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
423
+ "\n",
424
+ " model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
425
+ " variant=variant,\n",
426
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
427
+ ],
428
+ "outputs": [],
429
+ "execution_count": 14,
430
+ "metadata": {
431
+ "jupyter": {
432
+ "outputs_hidden": false,
433
+ "source_hidden": false
434
+ },
435
+ "nteract": {
436
+ "transient": {
437
+ "deleting": false
438
+ }
439
+ },
440
+ "gather": {
441
+ "logged": 1706550382160
442
+ }
443
+ }
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "source": [
448
+ "\n",
449
+ "base_models = [\n",
450
+ " \"bert-base-uncased\",\n",
451
+ " \"distilbert-base-uncased\",\n",
452
+ "]"
453
+ ],
454
+ "outputs": [],
455
+ "execution_count": 15,
456
+ "metadata": {
457
+ "gather": {
458
+ "logged": 1706550382318
459
+ }
460
+ }
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "source": [
465
+ "BATCH_SIZE=1\n",
466
+ "\n",
467
+ "train_from_model(\"biobert/Bio_ClinicalBERT/\")"
468
+ ],
469
+ "outputs": [
470
+ {
471
+ "output_type": "stream",
472
+ "name": "stdout",
473
+ "text": "Initialising training based on biobert/Bio_ClinicalBERT/...\nTokenising...\nLoading model...\n"
474
+ },
475
+ {
476
+ "output_type": "stream",
477
+ "name": "stderr",
478
+ "text": "Map: 100%|██████████| 2722/2722 [00:01<00:00, 2195.12 examples/s]\nAll TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nAll the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
479
+ },
480
+ {
481
+ "output_type": "display_data",
482
+ "data": {
483
+ "text/plain": "<IPython.core.display.HTML object>",
484
+ "text/html": "Finishing last run (ID:sg022tqh) before initializing another..."
485
+ },
486
+ "metadata": {}
487
+ },
488
+ {
489
+ "output_type": "display_data",
490
+ "data": {
491
+ "text/plain": "<IPython.core.display.HTML object>",
492
+ "text/html": " View run <strong style=\"color:#cdcd00\">daedra_0.01-biobert/Bio_ClinicalBERT/</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
493
+ },
494
+ "metadata": {}
495
+ },
496
+ {
497
+ "output_type": "display_data",
498
+ "data": {
499
+ "text/plain": "<IPython.core.display.HTML object>",
500
+ "text/html": "Find logs at: <code>./wandb/run-20240129_174816-sg022tqh/logs</code>"
501
+ },
502
+ "metadata": {}
503
+ },
504
+ {
505
+ "output_type": "display_data",
506
+ "data": {
507
+ "text/plain": "<IPython.core.display.HTML object>",
508
+ "text/html": "Successfully finished last run (ID:sg022tqh). Initializing new run:<br/>"
509
+ },
510
+ "metadata": {}
511
+ },
512
+ {
513
+ "output_type": "display_data",
514
+ "data": {
515
+ "text/plain": "<IPython.core.display.HTML object>",
516
+ "text/html": "Tracking run with wandb version 0.16.2"
517
+ },
518
+ "metadata": {}
519
+ },
520
+ {
521
+ "output_type": "display_data",
522
+ "data": {
523
+ "text/plain": "<IPython.core.display.HTML object>",
524
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_174936-kilkkg1j</code>"
525
+ },
526
+ "metadata": {}
527
+ },
528
+ {
529
+ "output_type": "display_data",
530
+ "data": {
531
+ "text/plain": "<IPython.core.display.HTML object>",
532
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">daedra_0.01-biobert/Bio_ClinicalBERT/</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
533
+ },
534
+ "metadata": {}
535
+ },
536
+ {
537
+ "output_type": "display_data",
538
+ "data": {
539
+ "text/plain": "<IPython.core.display.HTML object>",
540
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
541
+ },
542
+ "metadata": {}
543
+ },
544
+ {
545
+ "output_type": "display_data",
546
+ "data": {
547
+ "text/plain": "<IPython.core.display.HTML object>",
548
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j</a>"
549
+ },
550
+ "metadata": {}
551
+ },
552
+ {
553
+ "output_type": "stream",
554
+ "name": "stdout",
555
+ "text": "Starting training...\n"
556
+ },
557
+ {
558
+ "output_type": "stream",
559
+ "name": "stderr",
560
+ "text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
561
+ },
562
+ {
563
+ "output_type": "display_data",
564
+ "data": {
565
+ "text/plain": "<IPython.core.display.HTML object>",
566
+ "text/html": "\n <div>\n \n <progress value='1496' max='15880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1496/15880 07:43 < 1:14:19, 3.23 it/s, Epoch 0.47/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
567
+ },
568
+ "metadata": {}
569
+ }
570
+ ],
571
+ "execution_count": 21,
572
+ "metadata": {
573
+ "gather": {
574
+ "logged": 1706551053473
575
+ }
576
+ }
577
+ },
578
+ {
579
+ "cell_type": "code",
580
+ "source": [],
581
+ "outputs": [],
582
+ "execution_count": null,
583
+ "metadata": {
584
+ "jupyter": {
585
+ "source_hidden": false,
586
+ "outputs_hidden": false
587
+ },
588
+ "nteract": {
589
+ "transient": {
590
+ "deleting": false
591
+ }
592
+ }
593
+ }
594
+ }
595
+ ],
596
+ "metadata": {
597
+ "datalore": {
598
+ "base_environment": "default",
599
+ "computation_mode": "JUPYTER",
600
+ "package_manager": "pip",
601
+ "packages": [
602
+ {
603
+ "name": "datasets",
604
+ "source": "PIP",
605
+ "version": "2.16.1"
606
+ },
607
+ {
608
+ "name": "torch",
609
+ "source": "PIP",
610
+ "version": "2.1.2"
611
+ },
612
+ {
613
+ "name": "accelerate",
614
+ "source": "PIP",
615
+ "version": "0.26.1"
616
+ }
617
+ ],
618
+ "report_row_ids": [
619
+ "un8W7ez7ZwoGb5Co6nydEV",
620
+ "40nN9Hvgi1clHNV5RAemI5",
621
+ "TgRD90H5NSPpKS41OeXI1w",
622
+ "ZOm5BfUs3h1EGLaUkBGeEB",
623
+ "kOP0CZWNSk6vqE3wkPp7Vc",
624
+ "W4PWcOu2O2pRaZyoE2W80h",
625
+ "RolbOnQLIftk0vy9mIcz5M",
626
+ "8OPhUgbaNJmOdiq5D3a6vK",
627
+ "5Qrt3jSvSrpK6Ne1hS6shL",
628
+ "hTq7nFUrovN5Ao4u6dIYWZ",
629
+ "I8WNZLpJ1DVP2wiCW7YBIB",
630
+ "SawhU3I9BewSE1XBPstpNJ",
631
+ "80EtLEl2FIE4FqbWnUD3nT"
632
+ ],
633
+ "version": 3
634
+ },
635
+ "kernel_info": {
636
+ "name": "python38-azureml-pt-tf"
637
+ },
638
+ "kernelspec": {
639
+ "display_name": "azureml_py38_PT_TF",
640
+ "language": "python",
641
+ "name": "python3"
642
+ },
643
+ "language_info": {
644
+ "name": "python",
645
+ "version": "3.8.5",
646
+ "mimetype": "text/x-python",
647
+ "codemirror_mode": {
648
+ "name": "ipython",
649
+ "version": 3
650
+ },
651
+ "pygments_lexer": "ipython3",
652
+ "nbconvert_exporter": "python",
653
+ "file_extension": ".py"
654
+ },
655
+ "microsoft": {
656
+ "host": {
657
+ "AzureML": {
658
+ "notebookHasBeenCompleted": true
659
+ }
660
+ },
661
+ "ms_spell_check": {
662
+ "ms_spell_check_language": "en"
663
+ }
664
+ },
665
+ "nteract": {
666
+ "version": "nteract-front-end@1.0.0"
667
+ }
668
+ },
669
+ "nbformat": 4,
670
+ "nbformat_minor": 4
671
+ }
notebooks/DAEDRA.yml ADDED
File without changes
notebooks/Dataset preparation.ipynb ADDED
@@ -0,0 +1,524 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# Dataset processing\n",
7
+ "\n",
8
+ "This notebook processes the raw csv outputs from VAERS into Huggingface datasets. It shouldn't generally need to be run by the end user. "
9
+ ],
10
+ "metadata": {
11
+ "collapsed": false
12
+ },
13
+ "id": "35523bbeb2e03eae"
14
+ },
15
+ {
16
+ "cell_type": "code",
17
+ "outputs": [],
18
+ "source": [
19
+ "import pandas as pd\n",
20
+ "import datasets\n",
21
+ "import glob\n",
22
+ "import tqdm.notebook as tqdm\n",
23
+ "from sklearn.model_selection import train_test_split\n",
24
+ "from typing import Tuple\n",
25
+ "from datetime import datetime\n",
26
+ "\n",
27
+ "pd.set_option('future.no_silent_downcasting', True)"
28
+ ],
29
+ "metadata": {
30
+ "collapsed": false,
31
+ "ExecuteTime": {
32
+ "end_time": "2024-01-27T22:28:38.481853Z",
33
+ "start_time": "2024-01-27T22:28:38.458294Z"
34
+ }
35
+ },
36
+ "id": "9362802d64424442",
37
+ "execution_count": 15
38
+ },
39
+ {
40
+ "cell_type": "code",
41
+ "outputs": [],
42
+ "source": [
43
+ "HF_URL: str = \"chrisvoncsefalvay/vaers-outcomes\"\n",
44
+ "\n",
45
+ "FLAG_COLUMNS: list = [\"DIED\", \"ER_VISIT\", \"HOSPITAL\", \"OFC_VISIT\", \"X_STAY\", \"DISABLE\"]\n",
46
+ "DEMOGRAPHIC_COLUMNS: list = [\"AGE_YRS\", \"SEX\"]\n",
47
+ "DERIVED_COLUMNS: list = [\"D_PRESENTED\"]\n",
48
+ "ID_COLUMNS: list = [\"VAERS_ID\"]\n",
49
+ "TEXT_COLUMNS: list = [\"SYMPTOM_TEXT\"]\n",
50
+ "\n",
51
+ "TEST_TRAIN_FRACTION: float = 0.3\n",
52
+ "TRAIN_VAL_FRACTION: float = 0.5"
53
+ ],
54
+ "metadata": {
55
+ "collapsed": false,
56
+ "ExecuteTime": {
57
+ "end_time": "2024-01-27T22:28:38.498974Z",
58
+ "start_time": "2024-01-27T22:28:38.486237Z"
59
+ }
60
+ },
61
+ "id": "34b77edf5a1fce96",
62
+ "execution_count": 16
63
+ },
64
+ {
65
+ "cell_type": "markdown",
66
+ "source": [
67
+ "## Reading data files"
68
+ ],
69
+ "metadata": {
70
+ "collapsed": false
71
+ },
72
+ "id": "f5f84ddd06e9313e"
73
+ },
74
+ {
75
+ "cell_type": "code",
76
+ "outputs": [],
77
+ "source": [
78
+ "def read_aggregate(pattern: str) -> pd.DataFrame:\n",
79
+ " files = glob.glob(f\"../data/{pattern}\")\n",
80
+ " dfs = []\n",
81
+ " for file in tqdm.tqdm(files):\n",
82
+ " dfs.append(pd.read_csv(file, encoding=\"latin-1\", low_memory=False))\n",
83
+ "\n",
84
+ " res = pd.concat(dfs, ignore_index=True)\n",
85
+ " \n",
86
+ " print(f\"Processed {len(dfs)} files for a total of {len(res)} records.\")\n",
87
+ " \n",
88
+ " return res"
89
+ ],
90
+ "metadata": {
91
+ "collapsed": false,
92
+ "ExecuteTime": {
93
+ "end_time": "2024-01-27T22:28:38.508227Z",
94
+ "start_time": "2024-01-27T22:28:38.500697Z"
95
+ }
96
+ },
97
+ "id": "a7772ed4b4b51868",
98
+ "execution_count": 17
99
+ },
100
+ {
101
+ "cell_type": "code",
102
+ "outputs": [
103
+ {
104
+ "data": {
105
+ "text/plain": " 0%| | 0/1 [00:00<?, ?it/s]",
106
+ "application/vnd.jupyter.widget-view+json": {
107
+ "version_major": 2,
108
+ "version_minor": 0,
109
+ "model_id": "8a6919ed3c7e4c3a8885bb0991e856c7"
110
+ }
111
+ },
112
+ "metadata": {},
113
+ "output_type": "display_data"
114
+ },
115
+ {
116
+ "name": "stdout",
117
+ "output_type": "stream",
118
+ "text": [
119
+ "Processed 1 files for a total of 105726 records.\n"
120
+ ]
121
+ }
122
+ ],
123
+ "source": [
124
+ "data = read_aggregate(\"*VAERSDATA.csv\")"
125
+ ],
126
+ "metadata": {
127
+ "collapsed": false,
128
+ "ExecuteTime": {
129
+ "end_time": "2024-01-27T22:28:39.567031Z",
130
+ "start_time": "2024-01-27T22:28:38.510939Z"
131
+ }
132
+ },
133
+ "id": "795e389489cbc6cf",
134
+ "execution_count": 18
135
+ },
136
+ {
137
+ "cell_type": "code",
138
+ "outputs": [],
139
+ "source": [
140
+ "_keep: list = ID_COLUMNS + DEMOGRAPHIC_COLUMNS + TEXT_COLUMNS + FLAG_COLUMNS + [\"ER_ED_VISIT\"]\n",
141
+ "data = data[_keep]"
142
+ ],
143
+ "metadata": {
144
+ "collapsed": false,
145
+ "ExecuteTime": {
146
+ "end_time": "2024-01-27T22:28:39.603326Z",
147
+ "start_time": "2024-01-27T22:28:39.569131Z"
148
+ }
149
+ },
150
+ "id": "5297fca83e18b502",
151
+ "execution_count": 19
152
+ },
153
+ {
154
+ "cell_type": "markdown",
155
+ "source": [
156
+ "## Recoding\n",
157
+ "\n",
158
+ "We recode as follows:\n",
159
+ "\n",
160
+ "* For the outcome flags, `NaN` is recoded as `0` and `Y` is recoded as `1`.\n",
161
+ "* `ER_VISIT` and `ER_ED_VISIT` are coalesced into a single column called `ER_VISIT` that is `1`-valued if either is `1`-valued, otherwise it is `0`-valued. This is to manage the renaming of the column in the VAERS data.\n",
162
+ "* `NaN`s in the symptom text will drop the record."
163
+ ],
164
+ "metadata": {
165
+ "collapsed": false
166
+ },
167
+ "id": "9467a8081810458e"
168
+ },
169
+ {
170
+ "cell_type": "code",
171
+ "outputs": [],
172
+ "source": [
173
+ "def recode(df: pd.DataFrame) -> pd.DataFrame:\n",
174
+ " for column in FLAG_COLUMNS + [\"ER_ED_VISIT\"]:\n",
175
+ " df[column] = df[column].replace(\"Y\", 1).fillna(0).astype(int)\n",
176
+ " \n",
177
+ " df['ER_VISIT'] = df[['ER_VISIT', 'ER_ED_VISIT']].max(axis=1)\n",
178
+ " \n",
179
+ " df = df.drop(columns=['ER_ED_VISIT'])\n",
180
+ " \n",
181
+ " df = df.dropna(subset=['SYMPTOM_TEXT'])\n",
182
+ " \n",
183
+ " return df"
184
+ ],
185
+ "metadata": {
186
+ "collapsed": false,
187
+ "ExecuteTime": {
188
+ "end_time": "2024-01-27T22:28:39.603731Z",
189
+ "start_time": "2024-01-27T22:28:39.590617Z"
190
+ }
191
+ },
192
+ "id": "9aad00c9fe40adb8",
193
+ "execution_count": 20
194
+ },
195
+ {
196
+ "cell_type": "code",
197
+ "outputs": [],
198
+ "source": [],
199
+ "metadata": {
200
+ "collapsed": false,
201
+ "ExecuteTime": {
202
+ "end_time": "2024-01-27T22:28:39.604024Z",
203
+ "start_time": "2024-01-27T22:28:39.593891Z"
204
+ }
205
+ },
206
+ "id": "b0fdcab6ee807404",
207
+ "execution_count": 20
208
+ },
209
+ {
210
+ "cell_type": "code",
211
+ "outputs": [],
212
+ "source": [
213
+ "data = recode(data)"
214
+ ],
215
+ "metadata": {
216
+ "collapsed": false,
217
+ "ExecuteTime": {
218
+ "end_time": "2024-01-27T22:28:39.665777Z",
219
+ "start_time": "2024-01-27T22:28:39.597946Z"
220
+ }
221
+ },
222
+ "id": "f23ee0eae1b70387",
223
+ "execution_count": 21
224
+ },
225
+ {
226
+ "cell_type": "markdown",
227
+ "source": [
228
+ "## Derived fields\n",
229
+ "\n",
230
+ "We create the derived field `D_PRESENTED`. This is to provide a shorthand for patients who present in any way: ER, hospitalisation, office visit. It also comprises patients whose hospital stay is extended (`X_STAY`) as this is typically the consequence of presenting."
231
+ ],
232
+ "metadata": {
233
+ "collapsed": false
234
+ },
235
+ "id": "1c2f6b4fc2ae630b"
236
+ },
237
+ {
238
+ "cell_type": "code",
239
+ "outputs": [],
240
+ "source": [
241
+ "data['D_PRESENTED'] = data[['ER_VISIT', 'HOSPITAL', 'OFC_VISIT', 'X_STAY']].max(axis=1)"
242
+ ],
243
+ "metadata": {
244
+ "collapsed": false,
245
+ "ExecuteTime": {
246
+ "end_time": "2024-01-27T22:28:39.679534Z",
247
+ "start_time": "2024-01-27T22:28:39.667363Z"
248
+ }
249
+ },
250
+ "id": "678847c70756695e",
251
+ "execution_count": 22
252
+ },
253
+ {
254
+ "cell_type": "markdown",
255
+ "source": [
256
+ "## Test/train/validate split\n",
257
+ "\n",
258
+ "We do a stratified split by age quintile and gender into test, train and validate sets."
259
+ ],
260
+ "metadata": {
261
+ "collapsed": false
262
+ },
263
+ "id": "dae902b111c8ef3c"
264
+ },
265
+ {
266
+ "cell_type": "code",
267
+ "outputs": [],
268
+ "source": [
269
+ "def stratified_split(df: pd.DataFrame, test_train_fraction: float, train_val_fraction: float, random_state: int = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n",
270
+ " df['AGE_QUINTILE'] = pd.qcut(df['AGE_YRS'], 5, labels = False)\n",
271
+ " df['STRATIFICATION_VARIABLE'] = df['SEX'].astype(str) + \"_\" + df['AGE_QUINTILE'].astype(str)\n",
272
+ " df = df.drop(columns=['AGE_QUINTILE'])\n",
273
+ " \n",
274
+ " _, train = train_test_split(df, train_size=test_train_fraction, random_state=random_state, stratify=df.STRATIFICATION_VARIABLE)\n",
275
+ " \n",
276
+ " val, test = train_test_split(_, train_size=train_val_fraction, random_state=random_state, stratify=_.STRATIFICATION_VARIABLE)\n",
277
+ " \n",
278
+ " train = train.drop(columns=\"STRATIFICATION_VARIABLE\")\n",
279
+ " val = val.drop(columns=\"STRATIFICATION_VARIABLE\")\n",
280
+ " test = test.drop(columns=\"STRATIFICATION_VARIABLE\") \n",
281
+ " \n",
282
+ " return train, test, val"
283
+ ],
284
+ "metadata": {
285
+ "collapsed": false,
286
+ "ExecuteTime": {
287
+ "end_time": "2024-01-27T22:28:39.680497Z",
288
+ "start_time": "2024-01-27T22:28:39.678055Z"
289
+ }
290
+ },
291
+ "id": "ddee47653c94ff02",
292
+ "execution_count": 23
293
+ },
294
+ {
295
+ "cell_type": "code",
296
+ "outputs": [],
297
+ "source": [
298
+ "train, test, val = stratified_split(data, TEST_TRAIN_FRACTION, TRAIN_VAL_FRACTION)"
299
+ ],
300
+ "metadata": {
301
+ "collapsed": false,
302
+ "ExecuteTime": {
303
+ "end_time": "2024-01-27T22:28:39.863489Z",
304
+ "start_time": "2024-01-27T22:28:39.680464Z"
305
+ }
306
+ },
307
+ "id": "bb16aaad0127ef7d",
308
+ "execution_count": 24
309
+ },
310
+ {
311
+ "cell_type": "markdown",
312
+ "source": [
313
+ "## Converting to labels"
314
+ ],
315
+ "metadata": {
316
+ "collapsed": false
317
+ },
318
+ "id": "d61bfdc4a2879905"
319
+ },
320
+ {
321
+ "cell_type": "code",
322
+ "outputs": [],
323
+ "source": [
324
+ "def convert_to_dataset(df: pd.DataFrame) -> datasets.Dataset:\n",
325
+ " df = df.loc[:, ID_COLUMNS + TEXT_COLUMNS + FLAG_COLUMNS + DERIVED_COLUMNS]\n",
326
+ " \n",
327
+ " # We create the labels – these have to be floats for multilabel classification that uses BCEWithLogitsLoss\n",
328
+ " df.loc[:, \"labels\"] = df[FLAG_COLUMNS + DERIVED_COLUMNS].values.astype(float).tolist()\n",
329
+ " \n",
330
+ " print(f\"Building dataset with the following label order: {' '.join(FLAG_COLUMNS + DERIVED_COLUMNS)}\")\n",
331
+ " \n",
332
+ " # We drop the flag columns\n",
333
+ " df = df.drop(columns=FLAG_COLUMNS).drop(columns=DERIVED_COLUMNS)\n",
334
+ " \n",
335
+ " # We rename the remaining columns\n",
336
+ " df = df.rename(columns={\"SYMPTOM_TEXT\": \"text\", \"VAERS_ID\": \"id\"})\n",
337
+ " \n",
338
+ " return datasets.Dataset.from_pandas(df, preserve_index=False)"
339
+ ],
340
+ "metadata": {
341
+ "collapsed": false,
342
+ "ExecuteTime": {
343
+ "end_time": "2024-01-27T22:28:39.867392Z",
344
+ "start_time": "2024-01-27T22:28:39.864829Z"
345
+ }
346
+ },
347
+ "id": "3d602444d33b7130",
348
+ "execution_count": 25
349
+ },
350
+ {
351
+ "cell_type": "code",
352
+ "outputs": [
353
+ {
354
+ "name": "stdout",
355
+ "output_type": "stream",
356
+ "text": [
357
+ "Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n",
358
+ "Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n",
359
+ "Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n"
360
+ ]
361
+ }
362
+ ],
363
+ "source": [
364
+ "ds = datasets.DatasetDict()\n",
365
+ "ds[\"train\"] = convert_to_dataset(train)\n",
366
+ "ds[\"test\"] = convert_to_dataset(test)\n",
367
+ "ds[\"val\"] = convert_to_dataset(val)"
368
+ ],
369
+ "metadata": {
370
+ "collapsed": false,
371
+ "ExecuteTime": {
372
+ "end_time": "2024-01-27T22:28:40.207548Z",
373
+ "start_time": "2024-01-27T22:28:39.872665Z"
374
+ }
375
+ },
376
+ "id": "e7c854a072956ca3",
377
+ "execution_count": 26
378
+ },
379
+ {
380
+ "cell_type": "markdown",
381
+ "source": [
382
+ "## Saving to Huggingface Hub"
383
+ ],
384
+ "metadata": {
385
+ "collapsed": false
386
+ },
387
+ "id": "ec0167c068238f5a"
388
+ },
389
+ {
390
+ "cell_type": "code",
391
+ "outputs": [
392
+ {
393
+ "data": {
394
+ "text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]",
395
+ "application/vnd.jupyter.widget-view+json": {
396
+ "version_major": 2,
397
+ "version_minor": 0,
398
+ "model_id": "c196be983bbc474186dad4b75347aebb"
399
+ }
400
+ },
401
+ "metadata": {},
402
+ "output_type": "display_data"
403
+ },
404
+ {
405
+ "data": {
406
+ "text/plain": "Creating parquet from Arrow format: 0%| | 0/74 [00:00<?, ?ba/s]",
407
+ "application/vnd.jupyter.widget-view+json": {
408
+ "version_major": 2,
409
+ "version_minor": 0,
410
+ "model_id": "9bb3cbdfa4e84b96a68929fc3326536d"
411
+ }
412
+ },
413
+ "metadata": {},
414
+ "output_type": "display_data"
415
+ },
416
+ {
417
+ "data": {
418
+ "text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]",
419
+ "application/vnd.jupyter.widget-view+json": {
420
+ "version_major": 2,
421
+ "version_minor": 0,
422
+ "model_id": "66aa46f327264d7aa8f42f4a1bcf0775"
423
+ }
424
+ },
425
+ "metadata": {},
426
+ "output_type": "display_data"
427
+ },
428
+ {
429
+ "data": {
430
+ "text/plain": "Creating parquet from Arrow format: 0%| | 0/16 [00:00<?, ?ba/s]",
431
+ "application/vnd.jupyter.widget-view+json": {
432
+ "version_major": 2,
433
+ "version_minor": 0,
434
+ "model_id": "b14c57836adc4a3692f9594acc164ff0"
435
+ }
436
+ },
437
+ "metadata": {},
438
+ "output_type": "display_data"
439
+ },
440
+ {
441
+ "data": {
442
+ "text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]",
443
+ "application/vnd.jupyter.widget-view+json": {
444
+ "version_major": 2,
445
+ "version_minor": 0,
446
+ "model_id": "d395ca6f2c9b4ee5bb49dbce3a9bd064"
447
+ }
448
+ },
449
+ "metadata": {},
450
+ "output_type": "display_data"
451
+ },
452
+ {
453
+ "data": {
454
+ "text/plain": "Creating parquet from Arrow format: 0%| | 0/16 [00:00<?, ?ba/s]",
455
+ "application/vnd.jupyter.widget-view+json": {
456
+ "version_major": 2,
457
+ "version_minor": 0,
458
+ "model_id": "71780ee50ab649338bfa217f1767cca7"
459
+ }
460
+ },
461
+ "metadata": {},
462
+ "output_type": "display_data"
463
+ },
464
+ {
465
+ "data": {
466
+ "text/plain": "README.md: 0%| | 0.00/94.0 [00:00<?, ?B/s]",
467
+ "application/vnd.jupyter.widget-view+json": {
468
+ "version_major": 2,
469
+ "version_minor": 0,
470
+ "model_id": "1983f75eccf044649ab6423cad68dfdc"
471
+ }
472
+ },
473
+ "metadata": {},
474
+ "output_type": "display_data"
475
+ },
476
+ {
477
+ "data": {
478
+ "text/plain": "CommitInfo(commit_url='https://huggingface.co/datasets/chrisvoncsefalvay/vaers-outcomes/commit/65fa5129a0b1eb64f8fdd1aca5490965810e4ddb', commit_message='Data set commit of 105238 records of VAERS data at 2024-01-27T15:28:40.206686.', commit_description='', oid='65fa5129a0b1eb64f8fdd1aca5490965810e4ddb', pr_url='https://huggingface.co/datasets/chrisvoncsefalvay/vaers-outcomes/discussions/1', pr_revision='refs/pr/1', pr_num=1)"
479
+ },
480
+ "execution_count": 27,
481
+ "metadata": {},
482
+ "output_type": "execute_result"
483
+ }
484
+ ],
485
+ "source": [
486
+ "commit_message = f\"\"\"Data set commit of {len(train) + len(test) + len(val)} records of VAERS data at {datetime.now().isoformat()}.\"\"\"\n",
487
+ "\n",
488
+ "ds.push_to_hub(HF_URL, \n",
489
+ " commit_message=commit_message,\n",
490
+ " create_pr=True)"
491
+ ],
492
+ "metadata": {
493
+ "collapsed": false,
494
+ "ExecuteTime": {
495
+ "end_time": "2024-01-27T22:28:45.264233Z",
496
+ "start_time": "2024-01-27T22:28:40.207690Z"
497
+ }
498
+ },
499
+ "id": "104ffca720a27624",
500
+ "execution_count": 27
501
+ }
502
+ ],
503
+ "metadata": {
504
+ "kernelspec": {
505
+ "display_name": "Python 3",
506
+ "language": "python",
507
+ "name": "python3"
508
+ },
509
+ "language_info": {
510
+ "codemirror_mode": {
511
+ "name": "ipython",
512
+ "version": 2
513
+ },
514
+ "file_extension": ".py",
515
+ "mimetype": "text/x-python",
516
+ "name": "python",
517
+ "nbconvert_exporter": "python",
518
+ "pygments_lexer": "ipython2",
519
+ "version": "2.7.6"
520
+ }
521
+ },
522
+ "nbformat": 4,
523
+ "nbformat_minor": 5
524
+ }
notebooks/Untitled.ipynb ADDED
@@ -0,0 +1,33 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "code",
5
+ "execution_count": null,
6
+ "id": "8e9c5e9c-af14-4148-86bb-b04f18e4d13e",
7
+ "metadata": {},
8
+ "outputs": [],
9
+ "source": []
10
+ }
11
+ ],
12
+ "metadata": {
13
+ "kernelspec": {
14
+ "display_name": "Python 3.8 - Pytorch and Tensorflow",
15
+ "language": "python",
16
+ "name": "python38-azureml-pt-tf"
17
+ },
18
+ "language_info": {
19
+ "codemirror_mode": {
20
+ "name": "ipython",
21
+ "version": 3
22
+ },
23
+ "file_extension": ".py",
24
+ "mimetype": "text/x-python",
25
+ "name": "python",
26
+ "nbconvert_exporter": "python",
27
+ "pygments_lexer": "ipython3",
28
+ "version": "3.8.5"
29
+ }
30
+ },
31
+ "nbformat": 4,
32
+ "nbformat_minor": 5
33
+ }
notebooks/comparisons.csv ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:f915ebb630ffd80319041ec728d9c7123b821d8f96b4745e909e937213832d21
3
+ size 11079466
notebooks/daedra.ipynb.amltmp ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "source": [
6
+ "# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
7
+ "\n",
8
+ "DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
9
+ ],
10
+ "metadata": {}
11
+ },
12
+ {
13
+ "cell_type": "code",
14
+ "source": [
15
+ "%pip install accelerate -U"
16
+ ],
17
+ "outputs": [
18
+ {
19
+ "output_type": "stream",
20
+ "name": "stdout",
21
+ "text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
22
+ }
23
+ ],
24
+ "execution_count": 1,
25
+ "metadata": {
26
+ "gather": {
27
+ "logged": 1706475754655
28
+ },
29
+ "nteract": {
30
+ "transient": {
31
+ "deleting": false
32
+ }
33
+ },
34
+ "tags": []
35
+ }
36
+ },
37
+ {
38
+ "cell_type": "code",
39
+ "source": [
40
+ "%pip install transformers datasets shap watermark wandb evaluate codecarbon"
41
+ ],
42
+ "outputs": [
43
+ {
44
+ "output_type": "stream",
45
+ "name": "stdout",
46
+ "text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
47
+ }
48
+ ],
49
+ "execution_count": 2,
50
+ "metadata": {
51
+ "nteract": {
52
+ "transient": {
53
+ "deleting": false
54
+ }
55
+ }
56
+ }
57
+ },
58
+ {
59
+ "cell_type": "code",
60
+ "source": [
61
+ "import pandas as pd\n",
62
+ "import numpy as np\n",
63
+ "import torch\n",
64
+ "import os\n",
65
+ "from typing import List, Union\n",
66
+ "from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
67
+ "from datasets import load_dataset, Dataset, DatasetDict\n",
68
+ "import shap\n",
69
+ "import wandb\n",
70
+ "import evaluate\n",
71
+ "import logging\n",
72
+ "\n",
73
+ "wandb.finish()\n",
74
+ "\n",
75
+ "\n",
76
+ "os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
77
+ "\n",
78
+ "%load_ext watermark"
79
+ ],
80
+ "outputs": [
81
+ {
82
+ "output_type": "stream",
83
+ "name": "stderr",
84
+ "text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 17:46:15.020290: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 17:46:16.031641: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031779: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
85
+ }
86
+ ],
87
+ "execution_count": 3,
88
+ "metadata": {
89
+ "datalore": {
90
+ "hide_input_from_viewers": false,
91
+ "hide_output_from_viewers": false,
92
+ "node_id": "caZjjFP0OyQNMVgZDiwswE",
93
+ "report_properties": {
94
+ "rowId": "un8W7ez7ZwoGb5Co6nydEV"
95
+ },
96
+ "type": "CODE"
97
+ },
98
+ "gather": {
99
+ "logged": 1706550378660
100
+ },
101
+ "tags": []
102
+ }
103
+ },
104
+ {
105
+ "cell_type": "code",
106
+ "source": [
107
+ "device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
108
+ "\n",
109
+ "SEED: int = 42\n",
110
+ "\n",
111
+ "BATCH_SIZE: int = 32\n",
112
+ "EPOCHS: int = 5\n",
113
+ "model_ckpt: str = \"distilbert-base-uncased\"\n",
114
+ "\n",
115
+ "# WandB configuration\n",
116
+ "os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
117
+ "os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
118
+ "os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
119
+ ],
120
+ "outputs": [],
121
+ "execution_count": 4,
122
+ "metadata": {
123
+ "collapsed": false,
124
+ "gather": {
125
+ "logged": 1706550378812
126
+ },
127
+ "jupyter": {
128
+ "outputs_hidden": false
129
+ }
130
+ }
131
+ },
132
+ {
133
+ "cell_type": "code",
134
+ "source": [
135
+ "%watermark --iversion"
136
+ ],
137
+ "outputs": [
138
+ {
139
+ "output_type": "stream",
140
+ "name": "stdout",
141
+ "text": "shap : 0.44.1\npandas : 2.0.2\nwandb : 0.16.2\nre : 2.2.1\nevaluate: 0.4.1\ntorch : 1.12.0\nnumpy : 1.23.5\nlogging : 0.5.1.2\n\n"
142
+ }
143
+ ],
144
+ "execution_count": 5,
145
+ "metadata": {
146
+ "collapsed": false,
147
+ "jupyter": {
148
+ "outputs_hidden": false
149
+ }
150
+ }
151
+ },
152
+ {
153
+ "cell_type": "code",
154
+ "source": [
155
+ "!nvidia-smi"
156
+ ],
157
+ "outputs": [
158
+ {
159
+ "output_type": "stream",
160
+ "name": "stdout",
161
+ "text": "Mon Jan 29 17:46:18 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
162
+ }
163
+ ],
164
+ "execution_count": 6,
165
+ "metadata": {
166
+ "datalore": {
167
+ "hide_input_from_viewers": true,
168
+ "hide_output_from_viewers": true,
169
+ "node_id": "UU2oOJhwbIualogG1YyCMd",
170
+ "type": "CODE"
171
+ }
172
+ }
173
+ },
174
+ {
175
+ "cell_type": "markdown",
176
+ "source": [
177
+ "## Loading the data set"
178
+ ],
179
+ "metadata": {
180
+ "datalore": {
181
+ "hide_input_from_viewers": false,
182
+ "hide_output_from_viewers": false,
183
+ "node_id": "t45KHugmcPVaO0nuk8tGJ9",
184
+ "report_properties": {
185
+ "rowId": "40nN9Hvgi1clHNV5RAemI5"
186
+ },
187
+ "type": "MD"
188
+ }
189
+ }
190
+ },
191
+ {
192
+ "cell_type": "code",
193
+ "source": [
194
+ "dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
195
+ ],
196
+ "outputs": [],
197
+ "execution_count": 7,
198
+ "metadata": {
199
+ "collapsed": false,
200
+ "gather": {
201
+ "logged": 1706550381141
202
+ },
203
+ "jupyter": {
204
+ "outputs_hidden": false
205
+ }
206
+ }
207
+ },
208
+ {
209
+ "cell_type": "code",
210
+ "source": [
211
+ "dataset"
212
+ ],
213
+ "outputs": [
214
+ {
215
+ "output_type": "execute_result",
216
+ "execution_count": 8,
217
+ "data": {
218
+ "text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
219
+ },
220
+ "metadata": {}
221
+ }
222
+ ],
223
+ "execution_count": 8,
224
+ "metadata": {
225
+ "collapsed": false,
226
+ "gather": {
227
+ "logged": 1706550381303
228
+ },
229
+ "jupyter": {
230
+ "outputs_hidden": false,
231
+ "source_hidden": false
232
+ },
233
+ "nteract": {
234
+ "transient": {
235
+ "deleting": false
236
+ }
237
+ }
238
+ }
239
+ },
240
+ {
241
+ "cell_type": "code",
242
+ "source": [
243
+ "SUBSAMPLING = 0.01\n",
244
+ "\n",
245
+ "if SUBSAMPLING < 1:\n",
246
+ " _ = DatasetDict()\n",
247
+ " for each in dataset.keys():\n",
248
+ " _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
249
+ "\n",
250
+ " dataset = _"
251
+ ],
252
+ "outputs": [],
253
+ "execution_count": 9,
254
+ "metadata": {
255
+ "gather": {
256
+ "logged": 1706550381472
257
+ }
258
+ }
259
+ },
260
+ {
261
+ "cell_type": "markdown",
262
+ "source": [
263
+ "## Tokenisation and encoding"
264
+ ],
265
+ "metadata": {}
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "source": [
270
+ "def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
271
+ " return ds_enc"
272
+ ],
273
+ "outputs": [],
274
+ "execution_count": 10,
275
+ "metadata": {
276
+ "gather": {
277
+ "logged": 1706550381637
278
+ }
279
+ }
280
+ },
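+ {
+ "cell_type": "code",
+ "source": [
+ "# Sketch, illustrative rather than required: materialise a tokenised copy of the\n",
+ "# dataset with the default checkpoint. train_from_model() below performs the\n",
+ "# same tokenisation internally.\n",
+ "ds_enc = encode_ds(dataset)"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },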
281
+ {
282
+ "cell_type": "markdown",
283
+ "source": [
284
+ "## Evaluation metrics"
285
+ ],
286
+ "metadata": {}
287
+ },
288
+ {
289
+ "cell_type": "code",
290
+ "source": [
291
+ "accuracy = evaluate.load(\"accuracy\")\n",
292
+ "precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
293
+ "f1 = evaluate.load(\"f1\")"
294
+ ],
295
+ "outputs": [],
296
+ "execution_count": 11,
297
+ "metadata": {
298
+ "gather": {
299
+ "logged": 1706550381778
300
+ }
301
+ }
302
+ },
303
+ {
304
+ "cell_type": "code",
305
+ "source": [
306
+ "def compute_metrics(eval_pred):\n",
307
+ " predictions, labels = eval_pred\n",
308
+ " predictions = np.argmax(predictions, axis=1)\n",
309
+ " return {\n",
310
+ " 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
311
+ " 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
312
+ " 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
313
+ " 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
314
+ " 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
315
+ " 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
316
+ " }"
317
+ ],
318
+ "outputs": [],
319
+ "execution_count": 12,
320
+ "metadata": {
321
+ "gather": {
322
+ "logged": 1706550381891
323
+ }
324
+ }
325
+ },
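+ {
+ "cell_type": "code",
+ "source": [
+ "# Sanity-check sketch with fabricated logits and labels for a three-class toy\n",
+ "# case: argmax recovers classes [0, 1, 2], so every metric should come out as 1.0.\n",
+ "_logits = np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1], [0.1, 0.2, 1.9]])\n",
+ "_labels = np.array([0, 1, 2])\n",
+ "compute_metrics((_logits, _labels))"
+ ],
+ "outputs": [],
+ "execution_count": null,
+ "metadata": {}
+ },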
326
+ {
327
+ "cell_type": "markdown",
328
+ "source": [
329
+ "## Training"
330
+ ],
331
+ "metadata": {}
332
+ },
333
+ {
334
+ "cell_type": "markdown",
335
+ "source": [
336
+ "We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
337
+ ],
338
+ "metadata": {}
339
+ },
340
+ {
341
+ "cell_type": "code",
342
+ "source": [
343
+ "label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
344
+ ],
345
+ "outputs": [],
346
+ "execution_count": 13,
347
+ "metadata": {
348
+ "gather": {
349
+ "logged": 1706550382032
350
+ }
351
+ }
352
+ },
353
+ {
354
+ "cell_type": "code",
355
+ "source": [
356
+ "def train_from_model(model_ckpt: str, push: bool = False):\n",
357
+ " print(f\"Initialising training based on {model_ckpt}...\")\n",
358
+ "\n",
359
+ " print(\"Tokenising...\")\n",
360
+ " tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
361
+ "\n",
362
+ " cols = dataset[\"train\"].column_names\n",
363
+ " cols.remove(\"label\")\n",
364
+ " ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
365
+ "\n",
366
+ " print(\"Loading model...\")\n",
367
+ " try:\n",
368
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
369
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
370
+ " id2label=label_map, \n",
371
+ " label2id={v:k for k,v in label_map.items()})\n",
372
+ " except OSError:\n",
373
+ " model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
374
+ " num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
375
+ " id2label=label_map, \n",
376
+ " label2id={v:k for k,v in label_map.items()},\n",
377
+ " from_tf=True)\n",
378
+ "\n",
379
+ "\n",
380
+ " args = TrainingArguments(\n",
381
+ " output_dir=\"vaers\",\n",
382
+ " evaluation_strategy=\"epoch\",\n",
383
+ " save_strategy=\"epoch\",\n",
384
+ " learning_rate=2e-5,\n",
385
+ " per_device_train_batch_size=BATCH_SIZE,\n",
386
+ " per_device_eval_batch_size=BATCH_SIZE,\n",
387
+ " num_train_epochs=EPOCHS,\n",
388
+ " weight_decay=.01,\n",
389
+ " logging_steps=1,\n",
390
+ " load_best_model_at_end=True,\n",
391
+ " run_name=f\"daedra-training\",\n",
392
+ " report_to=[\"wandb\"])\n",
393
+ "\n",
394
+ " trainer = Trainer(\n",
395
+ " model=model,\n",
396
+ " args=args,\n",
397
+ " train_dataset=ds_enc[\"train\"],\n",
398
+ " eval_dataset=ds_enc[\"test\"],\n",
399
+ " tokenizer=tokenizer,\n",
400
+ " compute_metrics=compute_metrics)\n",
401
+ " \n",
402
+ " if SUBSAMPLING != 1.0:\n",
403
+ " wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
404
+ " else:\n",
405
+ " wandb_tag: List[str] = [f\"full_sample\"]\n",
406
+ "\n",
407
+ " wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
408
+ " wandb_tag.append(f\"base:{model_ckpt}\")\n",
409
+ " \n",
410
+ " wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
411
+ "\n",
412
+ " print(\"Starting training...\")\n",
413
+ "\n",
414
+ " trainer.train()\n",
415
+ "\n",
416
+ " print(\"Training finished.\")\n",
417
+ "\n",
418
+ " if push:\n",
419
+ " variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
420
+ " tokenizer._tokenizer.save(\"tokenizer.json\")\n",
421
+ " tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
422
+ " sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
423
+ "\n",
424
+ " model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
425
+ " variant=variant,\n",
426
+ " commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
427
+ ],
428
+ "outputs": [],
429
+ "execution_count": 14,
430
+ "metadata": {
431
+ "jupyter": {
432
+ "outputs_hidden": false,
433
+ "source_hidden": false
434
+ },
435
+ "nteract": {
436
+ "transient": {
437
+ "deleting": false
438
+ }
439
+ },
440
+ "gather": {
441
+ "logged": 1706550382160
442
+ }
443
+ }
444
+ },
445
+ {
446
+ "cell_type": "code",
447
+ "source": [
448
+ "\n",
449
+ "base_models = [\n",
450
+ " \"bert-base-uncased\",\n",
451
+ " \"distilbert-base-uncased\",\n",
452
+ "]"
453
+ ],
454
+ "outputs": [],
455
+ "execution_count": 15,
456
+ "metadata": {
457
+ "gather": {
458
+ "logged": 1706550382318
459
+ }
460
+ }
461
+ },
462
+ {
463
+ "cell_type": "code",
464
+ "source": [
465
+ "BATCH_SIZE=1\n",
466
+ "\n",
467
+ "train_from_model(\"biobert/Bio_ClinicalBERT/\")"
468
+ ],
469
+ "outputs": [
470
+ {
471
+ "output_type": "stream",
472
+ "name": "stdout",
473
+ "text": "Initialising training based on biobert/Bio_ClinicalBERT/...\nTokenising...\nLoading model...\n"
474
+ },
475
+ {
476
+ "output_type": "stream",
477
+ "name": "stderr",
478
+ "text": "Map: 100%|██████████| 2722/2722 [00:01<00:00, 2195.12 examples/s]\nAll TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nAll the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
479
+ },
480
+ {
481
+ "output_type": "display_data",
482
+ "data": {
483
+ "text/plain": "<IPython.core.display.HTML object>",
484
+ "text/html": "Finishing last run (ID:sg022tqh) before initializing another..."
485
+ },
486
+ "metadata": {}
487
+ },
488
+ {
489
+ "output_type": "display_data",
490
+ "data": {
491
+ "text/plain": "<IPython.core.display.HTML object>",
492
+ "text/html": " View run <strong style=\"color:#cdcd00\">daedra_0.01-biobert/Bio_ClinicalBERT/</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
493
+ },
494
+ "metadata": {}
495
+ },
496
+ {
497
+ "output_type": "display_data",
498
+ "data": {
499
+ "text/plain": "<IPython.core.display.HTML object>",
500
+ "text/html": "Find logs at: <code>./wandb/run-20240129_174816-sg022tqh/logs</code>"
501
+ },
502
+ "metadata": {}
503
+ },
504
+ {
505
+ "output_type": "display_data",
506
+ "data": {
507
+ "text/plain": "<IPython.core.display.HTML object>",
508
+ "text/html": "Successfully finished last run (ID:sg022tqh). Initializing new run:<br/>"
509
+ },
510
+ "metadata": {}
511
+ },
512
+ {
513
+ "output_type": "display_data",
514
+ "data": {
515
+ "text/plain": "<IPython.core.display.HTML object>",
516
+ "text/html": "Tracking run with wandb version 0.16.2"
517
+ },
518
+ "metadata": {}
519
+ },
520
+ {
521
+ "output_type": "display_data",
522
+ "data": {
523
+ "text/plain": "<IPython.core.display.HTML object>",
524
+ "text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_174936-kilkkg1j</code>"
525
+ },
526
+ "metadata": {}
527
+ },
528
+ {
529
+ "output_type": "display_data",
530
+ "data": {
531
+ "text/plain": "<IPython.core.display.HTML object>",
532
+ "text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">daedra_0.01-biobert/Bio_ClinicalBERT/</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
533
+ },
534
+ "metadata": {}
535
+ },
536
+ {
537
+ "output_type": "display_data",
538
+ "data": {
539
+ "text/plain": "<IPython.core.display.HTML object>",
540
+ "text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
541
+ },
542
+ "metadata": {}
543
+ },
544
+ {
545
+ "output_type": "display_data",
546
+ "data": {
547
+ "text/plain": "<IPython.core.display.HTML object>",
548
+ "text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j</a>"
549
+ },
550
+ "metadata": {}
551
+ },
552
+ {
553
+ "output_type": "stream",
554
+ "name": "stdout",
555
+ "text": "Starting training...\n"
556
+ },
557
+ {
558
+ "output_type": "stream",
559
+ "name": "stderr",
560
+ "text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
561
+ },
562
+ {
563
+ "output_type": "display_data",
564
+ "data": {
565
+ "text/plain": "<IPython.core.display.HTML object>",
566
+ "text/html": "\n <div>\n \n <progress value='1496' max='15880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1496/15880 07:43 < 1:14:19, 3.23 it/s, Epoch 0.47/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
567
+ },
568
+ "metadata": {}
569
+ }
570
+ ],
571
+ "execution_count": 21,
572
+ "metadata": {
573
+ "gather": {
574
+ "logged": 1706551053473
575
+ }
576
+ }
577
+ },
578
+ {
579
+ "cell_type": "code",
580
+ "source": [],
581
+ "outputs": [],
582
+ "execution_count": null,
583
+ "metadata": {
584
+ "jupyter": {
585
+ "source_hidden": false,
586
+ "outputs_hidden": false
587
+ },
588
+ "nteract": {
589
+ "transient": {
590
+ "deleting": false
591
+ }
592
+ }
593
+ }
594
+ }
595
+ ],
596
+ "metadata": {
597
+ "datalore": {
598
+ "base_environment": "default",
599
+ "computation_mode": "JUPYTER",
600
+ "package_manager": "pip",
601
+ "packages": [
602
+ {
603
+ "name": "datasets",
604
+ "source": "PIP",
605
+ "version": "2.16.1"
606
+ },
607
+ {
608
+ "name": "torch",
609
+ "source": "PIP",
610
+ "version": "2.1.2"
611
+ },
612
+ {
613
+ "name": "accelerate",
614
+ "source": "PIP",
615
+ "version": "0.26.1"
616
+ }
617
+ ],
618
+ "report_row_ids": [
619
+ "un8W7ez7ZwoGb5Co6nydEV",
620
+ "40nN9Hvgi1clHNV5RAemI5",
621
+ "TgRD90H5NSPpKS41OeXI1w",
622
+ "ZOm5BfUs3h1EGLaUkBGeEB",
623
+ "kOP0CZWNSk6vqE3wkPp7Vc",
624
+ "W4PWcOu2O2pRaZyoE2W80h",
625
+ "RolbOnQLIftk0vy9mIcz5M",
626
+ "8OPhUgbaNJmOdiq5D3a6vK",
627
+ "5Qrt3jSvSrpK6Ne1hS6shL",
628
+ "hTq7nFUrovN5Ao4u6dIYWZ",
629
+ "I8WNZLpJ1DVP2wiCW7YBIB",
630
+ "SawhU3I9BewSE1XBPstpNJ",
631
+ "80EtLEl2FIE4FqbWnUD3nT"
632
+ ],
633
+ "version": 3
634
+ },
635
+ "kernel_info": {
636
+ "name": "python38-azureml-pt-tf"
637
+ },
638
+ "kernelspec": {
639
+ "display_name": "azureml_py38_PT_TF",
640
+ "language": "python",
641
+ "name": "python3"
642
+ },
643
+ "language_info": {
644
+ "name": "python",
645
+ "version": "3.8.5",
646
+ "mimetype": "text/x-python",
647
+ "codemirror_mode": {
648
+ "name": "ipython",
649
+ "version": 3
650
+ },
651
+ "pygments_lexer": "ipython3",
652
+ "nbconvert_exporter": "python",
653
+ "file_extension": ".py"
654
+ },
655
+ "microsoft": {
656
+ "host": {
657
+ "AzureML": {
658
+ "notebookHasBeenCompleted": true
659
+ }
660
+ },
661
+ "ms_spell_check": {
662
+ "ms_spell_check_language": "en"
663
+ }
664
+ },
665
+ "nteract": {
666
+ "version": "nteract-front-end@1.0.0"
667
+ }
668
+ },
669
+ "nbformat": 4,
670
+ "nbformat_minor": 4
671
+ }
notebooks/daedra.py ADDED
@@ -0,0 +1,134 @@
+ import pandas as pd
+ import numpy as np
+ import torch
+ import os
+ from typing import List, Union
+ from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel
+ from datasets import load_dataset, Dataset, DatasetDict
+ import shap
+ import wandb
+ import evaluate
+ import logging
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ SEED: int = 42
+
+ BATCH_SIZE: int = 16
+ EPOCHS: int = 3
+ SUBSAMPLING: float = 0.1
+
+ # WandB configuration
+ os.environ["WANDB_PROJECT"] = "DAEDRA multiclass model training"
+ os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
+ os.environ["WANDB_NOTEBOOK_NAME"] = "DAEDRA.ipynb"
+
+ dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes")
+
+ if SUBSAMPLING < 1:
+     _ = DatasetDict()
+     for each in dataset.keys():
+         _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))
+
+     dataset = _
+
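+ # Evaluation metrics: accuracy plus macro- and micro-averaged precision and recall,
+ # and micro-averaged F1, all computed via the `evaluate` library.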
+ accuracy = evaluate.load("accuracy")
+ precision, recall = evaluate.load("precision"), evaluate.load("recall")
+ f1 = evaluate.load("f1")
+
+ def compute_metrics(eval_pred):
+     predictions, labels = eval_pred
+     predictions = np.argmax(predictions, axis=1)
+     return {
+         'accuracy': accuracy.compute(predictions=predictions, references=labels)["accuracy"],
+         'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')["precision"],
+         'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')["precision"],
+         'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')["recall"],
+         'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')["recall"],
+         'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
+     }
+
+ label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}
+
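+ # Fine-tunes a sequence classification head on the VAERS outcomes dataset,
+ # starting from the given base checkpoint.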
+ def train_from_model(model_ckpt: str, push: bool = False):
+     print(f"Initialising training based on {model_ckpt}...")
+
+     print("Tokenising...")
+     tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
+
+     cols = dataset["train"].column_names
+     cols.remove("label")
+     ds_enc = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=512), batched=True, remove_columns=cols)
+
+     print("Loading model...")
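+     # Some checkpoints ship TensorFlow weights only (cf. the "All TF 2.0 model weights
+     # were used" notice in the training logs above); retry with from_tf=True if the
+     # PyTorch load raises OSError.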
+     try:
+         model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
+                                                                    num_labels=len(dataset["test"].features["label"].names),
+                                                                    id2label=label_map,
+                                                                    label2id={v: k for k, v in label_map.items()})
+     except OSError:
+         model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
+                                                                    num_labels=len(dataset["test"].features["label"].names),
+                                                                    id2label=label_map,
+                                                                    label2id={v: k for k, v in label_map.items()},
+                                                                    from_tf=True)
+
+
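+     # Comparison-run hyperparameters: evaluate every 100 steps, log to W&B at every
+     # step, checkpoint once per epoch.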
+     args = TrainingArguments(
+         output_dir="vaers",
+         evaluation_strategy="steps",
+         eval_steps=100,
+         save_strategy="epoch",
+         learning_rate=2e-5,
+         per_device_train_batch_size=BATCH_SIZE,
+         per_device_eval_batch_size=BATCH_SIZE,
+         num_train_epochs=EPOCHS,
+         weight_decay=.01,
+         logging_steps=1,
+         run_name=f"daedra-minisample-comparison-{SUBSAMPLING}",
+         report_to=["wandb"])
+
+     trainer = Trainer(
+         model=model,
+         args=args,
+         train_dataset=ds_enc["train"],
+         eval_dataset=ds_enc["test"],
+         tokenizer=tokenizer,
+         compute_metrics=compute_metrics)
+
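+     # Tag the W&B run so subsampled and full-sample runs are easy to tell apart.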
+     if SUBSAMPLING != 1.0:
+         wandb_tag: List[str] = [f"subsample-{SUBSAMPLING}"]
+     else:
+         wandb_tag: List[str] = ["full_sample"]
+
+     wandb_tag.append(f"batch_size-{BATCH_SIZE}")
+     wandb_tag.append(f"base:{model_ckpt}")
+
+     if "/" in model_ckpt:
+         sanitised_model_name = model_ckpt.split("/")[1]
+     else:
+         sanitised_model_name = model_ckpt
+
+     wandb.init(name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}", tags=wandb_tag, magic=True)
+
+     print("Starting training...")
+
+     trainer.train()
+
+     print("Training finished.")
+
+     wandb.finish()
+
+ if __name__ == "__main__":
+     wandb.finish()
+
+     for mname in (
+         # "dmis-lab/biobert-base-cased-v1.2",
+         "emilyalsentzer/Bio_ClinicalBERT",
+         "bert-base-uncased",
+         "distilbert-base-uncased"
+     ):
+         print(f"Now training on subsample with {mname}...")
+         train_from_model(mname)
notebooks/daedra_final_training.py.amltmp ADDED
@@ -0,0 +1,136 @@
+ import pandas as pd
+ import numpy as np
+ import torch
+ import os
+ from typing import List, Union
+ from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel
+ from datasets import load_dataset, Dataset, DatasetDict
+ import shap
+ import wandb
+ import evaluate
+ import logging
+ from codecarbon import EmissionsTracker
+
+
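+ # CodeCarbon tracker: estimates the energy use and carbon emissions of the training
+ # run; results are written to emissions.csv.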
+ tracker = EmissionsTracker()
+
+ os.environ["TOKENIZERS_PARALLELISM"] = "false"
+
+ device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
+
+ SEED: int = 42
+
+ BATCH_SIZE: int = 16
+ EPOCHS: int = 3
+ SUBSAMPLING: float = 1
+
+ # WandB configuration
+ os.environ["WANDB_PROJECT"] = "DAEDRA final model training"
+ os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints
+
+ dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes")
+
+ if SUBSAMPLING < 1:
+     _ = DatasetDict()
+     for each in dataset.keys():
+         _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))
+
+     dataset = _
+
+ accuracy = evaluate.load("accuracy")
+ precision, recall = evaluate.load("precision"), evaluate.load("recall")
+ f1 = evaluate.load("f1")
+
+ def compute_metrics(eval_pred):
+     predictions, labels = eval_pred
+     predictions = np.argmax(predictions, axis=1)
+     return {
+         'accuracy': accuracy.compute(predictions=predictions, references=labels)["accuracy"],
+         'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')["precision"],
+         'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')["precision"],
+         'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')["recall"],
+         'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')["recall"],
+         'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
+     }
+
+ label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}
+
+ def train_from_model(model_ckpt: str, push: bool = False):
+     print(f"Initialising training based on {model_ckpt}...")
+
+     print("Tokenising...")
+     tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
+
+     cols = dataset["train"].column_names
+     cols.remove("label")
+     ds_enc = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=512), batched=True, remove_columns=cols)
+
+     print("Loading model...")
+     try:
+         model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
+                                                                    num_labels=len(dataset["test"].features["label"].names),
+                                                                    id2label=label_map,
+                                                                    label2id={v: k for k, v in label_map.items()})
+     except OSError:
+         model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
+                                                                    num_labels=len(dataset["test"].features["label"].names),
+                                                                    id2label=label_map,
+                                                                    label2id={v: k for k, v in label_map.items()},
+                                                                    from_tf=True)
+
+
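+     # Full-sample training: evaluate every 1,000 steps, checkpoint every 2,000, keep
+     # the two most recent checkpoints, push each save to the Hub, and reload the best
+     # checkpoint by micro-averaged F1 at the end of training.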
+     args = TrainingArguments(
+         output_dir="daedra",
+         evaluation_strategy="steps",
+         eval_steps=1000,
+         save_steps=2000,
+         save_strategy="steps",
+         learning_rate=2e-5,
+         per_device_train_batch_size=BATCH_SIZE,
+         per_device_eval_batch_size=BATCH_SIZE,
+         num_train_epochs=EPOCHS,
+         weight_decay=.01,
+         logging_steps=1,
+         run_name="daedra-full-train",
+         report_to=["wandb", "codecarbon"],
+         save_total_limit=2,
+         load_best_model_at_end=True,
+         push_to_hub=True,
+         push_to_hub_model_id="daedra",
+         hub_strategy="every_save",
+         metric_for_best_model="f1_microaverage")
+
+     trainer = Trainer(
+         model=model,
+         args=args,
+         train_dataset=ds_enc["train"],
+         eval_dataset=ds_enc["test"],
+         tokenizer=tokenizer,
+         compute_metrics=compute_metrics)
+
+     wandb_tag: List[str] = ["full_sample"]
+
+     wandb_tag.append(f"batch_size-{BATCH_SIZE}")
+     wandb_tag.append(f"base:{model_ckpt}")
+
+     if "/" in model_ckpt:
+         sanitised_model_name = model_ckpt.split("/")[1]
+     else:
+         sanitised_model_name = model_ckpt
+
+     wandb.init(name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}", tags=wandb_tag, magic=True)
+
+     print("Starting training...")
+
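+     # Bracket the training loop with the emissions tracker so the measurement covers
+     # training only.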
+     tracker.start()
+     trainer.train()
+     tracker.stop()
+
+     print("Training finished.")
+
+     wandb.finish()
+
+ if __name__ == "__main__":
+     wandb.finish()
+
+     train_from_model("dmis-lab/biobert-base-cased-v1.2")
notebooks/emissions.csv ADDED
@@ -0,0 +1,3 @@
+ timestamp,project_name,run_id,duration,emissions,emissions_rate,cpu_power,gpu_power,ram_power,cpu_energy,gpu_energy,ram_energy,energy_consumed,country_name,country_iso_code,region,cloud_provider,cloud_region,os,python_version,codecarbon_version,cpu_count,cpu_model,gpu_count,gpu_model,longitude,latitude,ram_total_size,tracking_mode,on_cloud,pue
+ 2024-01-29T03:05:13,codecarbon,6bfec408-4fcc-427a-8e94-0cabc9332665,10637.685039520264,0.9110964852888171,8.564800348045516e-05,42.5,148.01654697980965,165.33123922348022,0.1255709600533049,1.8546672673437372,0.4879591008899785,2.468197328287024,United States,USA,virginia,,,Linux-5.15.0-1040-azure-x86_64-with-glibc2.10,3.8.5,2.3.3,24,Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz,4,4 x Tesla V100-PCIE-16GB,-76.8545,37.9273,440.88330459594727,machine,N,1.0
+ 2024-01-29T14:29:57,codecarbon,484aeab5-8bdc-4fbc-8f66-0c204b0f2a,35031.49867606163,3.0209420758007166,8.623502247892816e-05,42.5,148.9286737800754,165.33123922348022,0.413520891640749,6.163406995166088,1.6069267112465608,8.183854598053408,United States,USA,virginia,,,Linux-5.15.0-1040-azure-x86_64-with-glibc2.10,3.8.5,2.3.3,24,Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz,4,4 x Tesla V100-PCIE-16GB,-76.8545,37.9273,440.88330459594727,machine,N,1.0
notebooks/microsample_model_comparison.ipynb ADDED
File without changes
notebooks/tokenizer.json ADDED
The diff for this file is too large to render. See raw diff
 
notebooks/wandb/.amlignore ADDED
@@ -0,0 +1,6 @@
1
+ ## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
2
+ ## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
3
+
4
+ .ipynb_aml_checkpoints/
5
+ *.amltmp
6
+ *.amltemp
paper/.gitkeep ADDED
File without changes
tokenizer.json CHANGED
@@ -1,6 +1,11 @@
  {
  "version": "1.0",
- "truncation": null,
+ "truncation": {
+ "direction": "Right",
+ "max_length": 512,
+ "strategy": "LongestFirst",
+ "stride": 0
+ },
  "padding": null,
  "added_tokens": [
  {
training_args.bin CHANGED
@@ -1,3 +1,3 @@
  version https://git-lfs.github.com/spec/v1
- oid sha256:f5d7ea168393ffb21085d167d41727aa4ff418441e573afbcbbf468e3ccd8d1c
+ oid sha256:c847edf58c0470c1a32d6d0f580f3c732e43c689025195de8e292f71fbb85be6
  size 4728