Training in progress, step 5000

Files changed:

- .amlignore +6 -0
- .amlignore.amltmp +6 -0
- .gitattributes +1 -0
- .gitignore +886 -0
- config.json +1 -1
- data/.amlignore +6 -0
- data/.amlignore.amltmp +6 -0
- data/.gitkeep +0 -0
- data/custom_vocab.txt +0 -0
- model.safetensors +2 -2
- notebooks/.amlignore +6 -0
- notebooks/.amlignore.amltmp +6 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-27-56Z.ipynb +744 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-52-4Z.ipynb +788 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-13-2-30Z.ipynb +1147 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-15-7-36Z.ipynb +1452 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-16-26-9Z.ipynb +1246 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-20-56-58Z.ipynb +993 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-23-54-39Z.ipynb +692 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-3-12-1Z.ipynb +1053 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-4-13-53Z.ipynb +0 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-14-26-30Z.ipynb +739 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-16-5-15Z.ipynb +729 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-17-44-52Z.ipynb +739 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-3-40-27Z.ipynb +1001 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-4-40-54Z.ipynb +1073 -0
- notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-30-21-44-8Z.ipynb +671 -0
- notebooks/.ipynb_aml_checkpoints/microsample_model_comparison-checkpoint2024-0-31-14-6-22Z.ipynb +0 -0
- notebooks/DAEDRA-Copy1.ipynb +1634 -0
- notebooks/DAEDRA.ipynb +671 -0
- notebooks/DAEDRA.yml +0 -0
- notebooks/Dataset preparation.ipynb +524 -0
- notebooks/Untitled.ipynb +33 -0
- notebooks/comparisons.csv +3 -0
- notebooks/daedra.ipynb.amltmp +671 -0
- notebooks/daedra.py +134 -0
- notebooks/daedra.py.amltmp +134 -0
- notebooks/daedra_final_training.py.amltmp +136 -0
- notebooks/emissions.csv +3 -0
- notebooks/emissions.csv.amltmp +3 -0
- notebooks/microsample_model_comparison.ipynb +0 -0
- notebooks/tokenizer.json +0 -0
- notebooks/wandb/.amlignore +6 -0
- notebooks/wandb/.amlignore.amltmp +6 -0
- paper/.gitkeep +0 -0
- tokenizer.json +6 -1
- training_args.bin +1 -1
.amlignore ADDED
@@ -0,0 +1,6 @@
+## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+.ipynb_aml_checkpoints/
+*.amltmp
+*.amltemp

.amlignore.amltmp ADDED
@@ -0,0 +1,6 @@
+## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+.ipynb_aml_checkpoints/
+*.amltmp
+*.amltemp

.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+notebooks/comparisons.csv filter=lfs diff=lfs merge=lfs -text

.gitignore ADDED
@@ -0,0 +1,886 @@
+### JetBrains template
+# Covers JetBrains IDEs: IntelliJ, RubyMine, PhpStorm, AppCode, PyCharm, CLion, Android Studio, WebStorm and Rider
+# Reference: https://intellij-support.jetbrains.com/hc/en-us/articles/206544839
+
+# User-specific stuff
+.idea/**/workspace.xml
+.idea/**/tasks.xml
+.idea/**/usage.statistics.xml
+.idea/**/dictionaries
+.idea/**/shelf
+
+# Data folder
+data/*.csv
+
+# AWS User-specific
+.idea/**/aws.xml
+
+# Generated files
+.idea/**/contentModel.xml
+
+# Sensitive or high-churn files
+.idea/**/dataSources/
+.idea/**/dataSources.ids
+.idea/**/dataSources.local.xml
+.idea/**/sqlDataSources.xml
+.idea/**/dynamic.xml
+.idea/**/uiDesigner.xml
+.idea/**/dbnavigator.xml
+
+# Gradle
+.idea/**/gradle.xml
+.idea/**/libraries
+
+# Gradle and Maven with auto-import
+# When using Gradle or Maven with auto-import, you should exclude module files,
+# since they will be recreated, and may cause churn. Uncomment if using
+# auto-import.
+# .idea/artifacts
+# .idea/compiler.xml
+# .idea/jarRepositories.xml
+# .idea/modules.xml
+# .idea/*.iml
+# .idea/modules
+# *.iml
+# *.ipr
+
+# CMake
+cmake-build-*/
+
+# Mongo Explorer plugin
+.idea/**/mongoSettings.xml
+
+# File-based project format
+*.iws
+
+# IntelliJ
+out/
+
+# mpeltonen/sbt-idea plugin
+.idea_modules/
+
+# JIRA plugin
+atlassian-ide-plugin.xml
+
+# Cursive Clojure plugin
+.idea/replstate.xml
+
+# SonarLint plugin
+.idea/sonarlint/
+
+# Crashlytics plugin (for Android Studio and IntelliJ)
+com_crashlytics_export_strings.xml
+crashlytics.properties
+crashlytics-build.properties
+fabric.properties
+
+# Editor-based Rest Client
+.idea/httpRequests
+
+# Android studio 3.1+ serialized cache file
+.idea/caches/build_file_checksums.ser
+
+### OSX template
+# General
+.DS_Store
+.AppleDouble
+.LSOverride
+
+# Icon must end with two \r
+Icon
+
+# Thumbnails
+._*
+
+# Files that might appear in the root of a volume
+.DocumentRevisions-V100
+.fseventsd
+.Spotlight-V100
+.TemporaryItems
+.Trashes
+.VolumeIcon.icns
+.com.apple.timemachine.donotpresent
+
+# Directories potentially created on remote AFP share
+.AppleDB
+.AppleDesktop
+Network Trash Folder
+Temporary Items
+.apdisk
+
+### TeX template
+## Core latex/pdflatex auxiliary files:
+*.aux
+*.lof
+*.log
+*.lot
+*.fls
+*.out
+*.toc
+*.fmt
+*.fot
+*.cb
+*.cb2
+.*.lb
+
+## Intermediate documents:
+*.dvi
+*.xdv
+*-converted-to.*
+# these rules might exclude image files for figures etc.
+# *.ps
+# *.eps
+# *.pdf
+
+## Generated if empty string is given at "Please type another file name for output:"
+.pdf
+
+## Bibliography auxiliary files (bibtex/biblatex/biber):
+*.bbl
+*.bcf
+*.blg
+*-blx.aux
+*-blx.bib
+*.run.xml
+
+## Build tool auxiliary files:
+*.fdb_latexmk
+*.synctex
+*.synctex(busy)
+*.synctex.gz
+*.synctex.gz(busy)
+*.pdfsync
+
+## Build tool directories for auxiliary files
+# latexrun
+latex.out/
+
+## Auxiliary and intermediate files from other packages:
+# algorithms
+*.alg
+*.loa
+
+# achemso
+acs-*.bib
+
+# amsthm
+*.thm
+
+# beamer
+*.nav
+*.pre
+*.snm
+*.vrb
+
+# changes
+*.soc
+
+# comment
+*.cut
+
+# cprotect
+*.cpt
+
+# elsarticle (documentclass of Elsevier journals)
+*.spl
+
+# endnotes
+*.ent
+
+*.lox
+
+# feynmf/feynmp
+*.mf
+*.mp
+*.t[1-9]
+*.t[1-9][0-9]
+*.tfm
+
+#(r)(e)ledmac/(r)(e)ledpar
+*.end
+*.?end
+*.[1-9]
+*.[1-9][0-9]
+*.[1-9][0-9][0-9]
+*.[1-9]R
+*.[1-9][0-9]R
+*.[1-9][0-9][0-9]R
+*.eledsec[1-9]
+*.eledsec[1-9]R
+*.eledsec[1-9][0-9]
+*.eledsec[1-9][0-9]R
+*.eledsec[1-9][0-9][0-9]
+*.eledsec[1-9][0-9][0-9]R
+
+# glossaries
+*.acn
+*.acr
+*.glg
+*.glo
+*.gls
+*.glsdefs
+*.lzo
+*.lzs
+*.slg
+*.slo
+*.sls
+
+# uncomment this for glossaries-extra (will ignore makeindex's style files!)
+# *.ist
+
+# gnuplot
+*.gnuplot
+*.table
+
+# gnuplottex
+*-gnuplottex-*
+
+# gregoriotex
+*.gaux
+*.glog
+*.gtex
+
+# htlatex
+*.4ct
+*.4tc
+*.idv
+*.lg
+*.trc
+*.xref
+
+# hyperref
+*.brf
+
+# knitr
+*-concordance.tex
+# *.tikz
+*-tikzDictionary
+
+# listings
+*.lol
+
+# luatexja-ruby
+*.ltjruby
+
+# makeidx
+*.idx
+*.ilg
+*.ind
+
+# minitoc
+*.maf
+*.mlf
+*.mlt
+*.mtc[0-9]*
+*.slf[0-9]*
+*.slt[0-9]*
+*.stc[0-9]*
+
+# minted
+_minted*
+*.pyg
+
+# morewrites
+*.mw
+
+# newpax
+*.newpax
+
+# nomencl
+*.nlg
+*.nlo
+*.nls
+
+# pax
+*.pax
+
+# pdfpcnotes
+*.pdfpc
+
+# sagetex
+*.sagetex.sage
+*.sagetex.py
+*.sagetex.scmd
+
+# scrwfile
+*.wrt
+
+# svg
+svg-inkscape/
+
+# sympy
+*.sout
+*.sympy
+sympy-plots-for-*.tex/
+
+# pdfcomment
+*.upa
+*.upb
+
+# pythontex
+*.pytxcode
+pythontex-files-*/
+
+# tcolorbox
+*.listing
+
+# thmtools
+*.loe
+
+# TikZ & PGF
+*.dpth
+*.md5
+*.auxlock
+
+# titletoc
+*.ptc
+
+# todonotes
+*.tdo
+
+# vhistory
+*.hst
+*.ver
+
+*.lod
+
+# xcolor
+*.xcp
+
+# xmpincl
+*.xmpi
+
+# xindy
+*.xdy
+
+# xypic precompiled matrices and outlines
+*.xyc
+*.xyd
+
+# endfloat
+*.ttt
+*.fff
+
+# Latexian
+TSWLatexianTemp*
+
+## Editors:
+# WinEdt
+*.bak
+*.sav
+
+# Texpad
+.texpadtmp
+
+# LyX
+*.lyx~
+
+# Kile
+*.backup
+
+# gummi
+.*.swp
+
+# KBibTeX
+*~[0-9]*
+
+# TeXnicCenter
+*.tps
+
+# auto folder when using emacs and auctex
+./auto/*
+*.el
+
+# expex forward references with \gathertags
+*-tags.tex
+
+# standalone packages
+*.sta
+
+# Makeindex log files
+*.lpz
+
+# xwatermark package
+*.xwm
+
+# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
+# option is specified. Footnotes are the stored in a file with suffix Notes.bib.
+# Uncomment the next line to have this generated file ignored.
+#*Notes.bib
+
+### JupyterNotebooks template
+# gitignore template for Jupyter Notebooks
+# website: http://jupyter.org/
+
+.ipynb_checkpoints
+*/.ipynb_checkpoints/*
+
+# IPython
+profile_default/
+ipython_config.py
+
+# Remove previous ipynb_checkpoints
+# git rm -r .ipynb_checkpoints/
+
+### LaTeX template
+## Core latex/pdflatex auxiliary files:
+*.aux
+*.lof
+*.log
+*.lot
+*.fls
+*.out
+*.toc
+*.fmt
+*.fot
+*.cb
+*.cb2
+.*.lb
+
+## Intermediate documents:
+*.dvi
+*.xdv
+*-converted-to.*
+# these rules might exclude image files for figures etc.
+# *.ps
+# *.eps
+# *.pdf
+
+## Generated if empty string is given at "Please type another file name for output:"
+.pdf
+
+## Bibliography auxiliary files (bibtex/biblatex/biber):
+*.bbl
+*.bcf
+*.blg
+*-blx.aux
+*-blx.bib
+*.run.xml
+
+## Build tool auxiliary files:
+*.fdb_latexmk
+*.synctex
+*.synctex(busy)
+*.synctex.gz
+*.synctex.gz(busy)
+*.pdfsync
+
+## Build tool directories for auxiliary files
+# latexrun
+latex.out/
+
+## Auxiliary and intermediate files from other packages:
+# algorithms
+*.alg
+*.loa
+
+# achemso
+acs-*.bib
+
+# amsthm
+*.thm
+
+# beamer
+*.nav
+*.pre
+*.snm
+*.vrb
+
+# changes
+*.soc
+
+# comment
+*.cut
+
+# cprotect
+*.cpt
+
+# elsarticle (documentclass of Elsevier journals)
+*.spl
+
+# endnotes
+*.ent
+
+*.lox
+
+# feynmf/feynmp
+*.mf
+*.mp
+*.t[1-9]
+*.t[1-9][0-9]
+*.tfm
+
+#(r)(e)ledmac/(r)(e)ledpar
+*.end
+*.?end
+*.[1-9]
+*.[1-9][0-9]
+*.[1-9][0-9][0-9]
+*.[1-9]R
+*.[1-9][0-9]R
+*.[1-9][0-9][0-9]R
+*.eledsec[1-9]
+*.eledsec[1-9]R
+*.eledsec[1-9][0-9]
+*.eledsec[1-9][0-9]R
+*.eledsec[1-9][0-9][0-9]
+*.eledsec[1-9][0-9][0-9]R
+
+# glossaries
+*.acn
+*.acr
+*.glg
+*.glo
+*.gls
+*.glsdefs
+*.lzo
+*.lzs
+*.slg
+*.slo
+*.sls
+
+# uncomment this for glossaries-extra (will ignore makeindex's style files!)
+# *.ist
+
+# gnuplot
+*.gnuplot
+*.table
+
+# gnuplottex
+*-gnuplottex-*
+
+# gregoriotex
+*.gaux
+*.glog
+*.gtex
+
+# htlatex
+*.4ct
+*.4tc
+*.idv
+*.lg
+*.trc
+*.xref
+
+# hyperref
+*.brf
+
+# knitr
+*-concordance.tex
+# *.tikz
+*-tikzDictionary
+
+# listings
+*.lol
+
+# luatexja-ruby
+*.ltjruby
+
+# makeidx
+*.idx
+*.ilg
+*.ind
+
+# minitoc
+*.maf
+*.mlf
+*.mlt
+*.mtc[0-9]*
+*.slf[0-9]*
+*.slt[0-9]*
+*.stc[0-9]*
+
+# minted
+_minted*
+*.pyg
+
+# morewrites
+*.mw
+
+# newpax
+*.newpax
+
+# nomencl
+*.nlg
+*.nlo
+*.nls
+
+# pax
+*.pax
+
+# pdfpcnotes
+*.pdfpc
+
+# sagetex
+*.sagetex.sage
+*.sagetex.py
+*.sagetex.scmd
+
+# scrwfile
+*.wrt
+
+# svg
+svg-inkscape/
+
+# sympy
+*.sout
+*.sympy
+sympy-plots-for-*.tex/
+
+# pdfcomment
+*.upa
+*.upb
+
+# pythontex
+*.pytxcode
+pythontex-files-*/
+
+# tcolorbox
+*.listing
+
+# thmtools
+*.loe
+
+# TikZ & PGF
+*.dpth
+*.md5
+*.auxlock
+
+# titletoc
+*.ptc
+
+# todonotes
+*.tdo
+
+# vhistory
+*.hst
+*.ver
+
+*.lod
+
+# xcolor
+*.xcp
+
+# xmpincl
+*.xmpi
+
+# xindy
+*.xdy
+
+# xypic precompiled matrices and outlines
+*.xyc
+*.xyd
+
+# endfloat
+*.ttt
+*.fff
+
+# Latexian
+TSWLatexianTemp*
+
+## Editors:
+# WinEdt
+*.bak
+*.sav
+
+# Texpad
+.texpadtmp
+
+# LyX
+*.lyx~
+
+# Kile
+*.backup
+
+# gummi
+.*.swp
+
+# KBibTeX
+*~[0-9]*
+
+# TeXnicCenter
+*.tps
+
+# auto folder when using emacs and auctex
+./auto/*
+*.el
+
+# expex forward references with \gathertags
+*-tags.tex
+
+# standalone packages
+*.sta
+
+# Makeindex log files
+*.lpz
+
+# xwatermark package
+*.xwm
+
+# REVTeX puts footnotes in the bibliography by default, unless the nofootinbib
+# option is specified. Footnotes are the stored in a file with suffix Notes.bib.
+# Uncomment the next line to have this generated file ignored.
+#*Notes.bib
+
+### Python template
+# Byte-compiled / optimized / DLL files
+__pycache__/
+*.py[cod]
+*$py.class
+
+# C extensions
+*.so
+
+# Distribution / packaging
+.Python
+build/
+develop-eggs/
+dist/
+downloads/
+eggs/
+.eggs/
+lib/
+lib64/
+parts/
+sdist/
+var/
+wheels/
+share/python-wheels/
+*.egg-info/
+.installed.cfg
+*.egg
+MANIFEST
+
+# PyInstaller
+# Usually these files are written by a python script from a template
+# before PyInstaller builds the exe, so as to inject date/other infos into it.
+*.manifest
+*.spec
+
+# Installer logs
+pip-log.txt
+pip-delete-this-directory.txt
+
+# Unit test / coverage reports
+htmlcov/
+.tox/
+.nox/
+.coverage
+.coverage.*
+.cache
+nosetests.xml
+coverage.xml
+*.cover
+*.py,cover
+.hypothesis/
+.pytest_cache/
+cover/
+
+# Translations
+*.mo
+*.pot
+
+# Django stuff:
+*.log
+local_settings.py
+db.sqlite3
+db.sqlite3-journal
+
+# Flask stuff:
+instance/
+.webassets-cache
+
+# Scrapy stuff:
+.scrapy
+
+# Sphinx documentation
+docs/_build/
+
+# PyBuilder
+.pybuilder/
+target/
+
+# Jupyter Notebook
+.ipynb_checkpoints
+
+# IPython
+profile_default/
+ipython_config.py
+
+# pyenv
+# For a library or package, you might want to ignore these files since the code is
+# intended to run in multiple environments; otherwise, check them in:
+# .python-version
+
+# pipenv
+# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
+# However, in case of collaboration, if having platform-specific dependencies or dependencies
+# having no cross-platform support, pipenv may install dependencies that don't work, or not
+# install all needed dependencies.
+#Pipfile.lock
+
+# poetry
+# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
+# This is especially recommended for binary packages to ensure reproducibility, and is more
+# commonly ignored for libraries.
+# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
+#poetry.lock
+
+# pdm
+# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
+#pdm.lock
+# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
+# in version control.
+# https://pdm.fming.dev/#use-with-ide
+.pdm.toml
+
+# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
+__pypackages__/
+
+# Celery stuff
+celerybeat-schedule
+celerybeat.pid
+
+# SageMath parsed files
+*.sage.py
+
+# Environments
+.env
+.venv
+env/
+venv/
+ENV/
+env.bak/
+venv.bak/
+
+# Spyder project settings
+.spyderproject
+.spyproject
+
+# Rope project settings
+.ropeproject
+
+# mkdocs documentation
+/site
+
+# mypy
+.mypy_cache/
+.dmypy.json
+dmypy.json
+
+# Pyre type checker
+.pyre/
+
+# pytype static type analyzer
+.pytype/
+
+# Cython debug symbols
+cython_debug/
+
+# PyCharm
+# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
+# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
+# and can be added to the global gitignore or merged into this file. For a more nuclear
+# option (not recommended) you can uncomment the following to ignore the entire idea folder.
+#.idea/
+

config.json CHANGED
@@ -42,5 +42,5 @@
   "transformers_version": "4.37.2",
   "type_vocab_size": 2,
   "use_cache": true,
-  "vocab_size":
+  "vocab_size": 52000
 }

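The only substantive change in config.json is the vocabulary size. As a quick sanity check, a minimal sketch (the repository id `user/daedra` is a placeholder; substitute the actual repo) of confirming that the published configuration and tokenizer agree on it:

    from transformers import AutoConfig, AutoTokenizer

    REPO_ID = "user/daedra"  # hypothetical id; substitute the actual model repository

    config = AutoConfig.from_pretrained(REPO_ID)
    tokenizer = AutoTokenizer.from_pretrained(REPO_ID)

    # This commit sets vocab_size to 52000; the embedding matrix and the tokenizer
    # must agree on it, or token ids past the embedding table will fail at inference.
    assert config.vocab_size == 52000
    assert len(tokenizer) <= config.vocab_size
    print(config.vocab_size, len(tokenizer))
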
data/.amlignore ADDED
@@ -0,0 +1,6 @@
+## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+.ipynb_aml_checkpoints/
+*.amltmp
+*.amltemp

data/.amlignore.amltmp ADDED
@@ -0,0 +1,6 @@
+## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+.ipynb_aml_checkpoints/
+*.amltmp
+*.amltemp

data/.gitkeep ADDED (file without changes)

data/custom_vocab.txt ADDED (file without changes)

model.safetensors CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
-size
+oid sha256:fd4f0f89ac5e5fa87847f7574a0c4343175b23afd37140f15173824a44dfee61
+size 503957528

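The model.safetensors entry is a Git LFS pointer: the repository records only a SHA-256 object id and a byte size, while the roughly 504 MB of weights live in LFS storage. A minimal sketch, assuming a local clone with the LFS object checked out at `model.safetensors`, of verifying a download against the pointer recorded in this commit:

    import hashlib
    from pathlib import Path

    EXPECTED_OID = "fd4f0f89ac5e5fa87847f7574a0c4343175b23afd37140f15173824a44dfee61"
    EXPECTED_SIZE = 503957528  # bytes, as recorded in the LFS pointer above

    path = Path("model.safetensors")
    assert path.stat().st_size == EXPECTED_SIZE, "size mismatch: incomplete download?"

    # Hash in 1 MiB chunks so the ~504 MB file never has to fit in memory at once.
    digest = hashlib.sha256()
    with path.open("rb") as fh:
        for chunk in iter(lambda: fh.read(1 << 20), b""):
            digest.update(chunk)

    assert digest.hexdigest() == EXPECTED_OID, "checksum mismatch: corrupt download?"
    print("model.safetensors matches the LFS pointer")
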
notebooks/.amlignore ADDED
@@ -0,0 +1,6 @@
+## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+.ipynb_aml_checkpoints/
+*.amltmp
+*.amltemp

notebooks/.amlignore.amltmp ADDED
@@ -0,0 +1,6 @@
+## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
+## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots
+
+.ipynb_aml_checkpoints/
+*.amltmp
+*.amltemp

notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-27-56Z.ipynb ADDED
@@ -0,0 +1,744 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"source": [
+"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
+"\n",
+"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"import pandas as pd\n",
+"import numpy as np\n",
+"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
+"import torch\n",
+"import os\n",
+"from typing import List\n",
+"from datasets import load_dataset\n",
+"import shap\n",
+"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
+"\n",
+"%load_ext watermark"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "caZjjFP0OyQNMVgZDiwswE",
+"type": "CODE",
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"report_properties": {
+"rowId": "un8W7ez7ZwoGb5Co6nydEV"
+}
+}
+}
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+"\n",
+"SEED: int = 42\n",
+"\n",
+"BATCH_SIZE: int = 8\n",
+"EPOCHS: int = 1\n",
+"model_ckpt: str = \"distilbert-base-uncased\"\n",
+"\n",
+"CLASS_NAMES: List[str] = [\"DIED\",\n",
+" \"ER_VISIT\",\n",
+" \"HOSPITAL\",\n",
+" \"OFC_VISIT\",\n",
+" \"X_STAY\",\n",
+" \"DISABLE\",\n",
+" \"D_PRESENTED\"]\n",
+"\n",
+"# WandB configuration\n",
+"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
+"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"%watermark --iversion"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"!nvidia-smi"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "UU2oOJhwbIualogG1YyCMd",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Loading the data set"
+],
+"attachments": {},
+"metadata": {
+"datalore": {
+"node_id": "t45KHugmcPVaO0nuk8tGJ9",
+"type": "MD",
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"report_properties": {
+"rowId": "40nN9Hvgi1clHNV5RAemI5"
+}
+}
+}
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"### Tokenisation and encoding"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "I7n646PIscsUZRoHu6m7zm",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"def tokenize_and_encode(examples):\n",
+" return tokenizer(examples[\"text\"], truncation=True)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"cols = dataset[\"train\"].column_names\n",
+"cols.remove(\"labels\")\n",
+"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "slHeNysZOX9uWS9PB7jFDb",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"### Training"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"class MultiLabelTrainer(Trainer):\n",
+" def compute_loss(self, model, inputs, return_outputs=False):\n",
+" labels = inputs.pop(\"labels\")\n",
+" outputs = model(**inputs)\n",
+" logits = outputs.logits\n",
+" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
+" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
+" labels.float().view(-1, self.model.config.num_labels))\n",
+" return (loss, outputs) if return_outputs else loss"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "itXWkbDw9sqbkMuDP84QoT",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to(\"cuda\")"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
+" y_pred = torch.from_numpy(y_pred)\n",
+" y_true = torch.from_numpy(y_true)\n",
+"\n",
+" if sigmoid:\n",
+" y_pred = y_pred.sigmoid()\n",
+"\n",
+" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "swhgyyyxoGL8HjnXJtMuSW",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"def compute_metrics(eval_pred):\n",
+" predictions, labels = eval_pred\n",
+" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"args = TrainingArguments(\n",
+" output_dir=\"vaers\",\n",
+" evaluation_strategy=\"epoch\",\n",
+" learning_rate=2e-5,\n",
+" per_device_train_batch_size=BATCH_SIZE,\n",
+" per_device_eval_batch_size=BATCH_SIZE,\n",
+" num_train_epochs=EPOCHS,\n",
+" weight_decay=.01,\n",
+" report_to=[\"wandb\"]\n",
+")"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "1iPZOTKPwSkTgX5dORqT89",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"multi_label_trainer = MultiLabelTrainer(\n",
+" model, \n",
+" args, \n",
+" train_dataset=ds_enc[\"train\"], \n",
+" eval_dataset=ds_enc[\"test\"], \n",
+" compute_metrics=compute_metrics, \n",
+" tokenizer=tokenizer\n",
+")"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "bnRkNvRYltLun6gCEgL7v0",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"multi_label_trainer.evaluate()"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "LO54PlDkWQdFrzV25FvduB",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"multi_label_trainer.train()"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"### Evaluation"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"We instantiate a classifier `pipeline` and push it to CUDA."
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"classifier = pipeline(\"text-classification\", \n",
+" model, \n",
+" tokenizer=tokenizer, \n",
+" device=\"cuda:0\")"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "kHoUdBeqcyVXDSGv54C4aE",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"We use the same tokenizer used for training to tokenize/encode the validation set."
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"test_encodings = tokenizer.batch_encode_plus(dataset[\"validate\"][\"text\"], \n",
+" max_length=255, \n",
+" pad_to_max_length=True, \n",
+" return_token_type_ids=True, \n",
+" truncation=True)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"Once we've made the data loadable by putting it into a `DataLoader`, we "
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
+" torch.tensor(test_encodings['attention_mask']), \n",
+" torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n",
+" torch.tensor(test_encodings['token_type_ids']))\n",
+"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
+" sampler=torch.utils.data.SequentialSampler(test_data), \n",
+" batch_size=BATCH_SIZE)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"model.eval()\n",
+"\n",
+"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
+"\n",
+"for i, batch in enumerate(test_dataloader):\n",
+" batch = tuple(t.to(device) for t in batch)\n",
+" # Unpack the inputs from our dataloader\n",
+" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
+" \n",
+" with torch.no_grad():\n",
+" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
+" b_logit_pred = outs[0]\n",
+" pred_label = torch.sigmoid(b_logit_pred)\n",
+"\n",
+" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
+" pred_label = pred_label.to('cpu').numpy()\n",
+" b_labels = b_labels.to('cpu').numpy()\n",
+"\n",
+" tokenized_texts.append(b_input_ids)\n",
+" logit_preds.append(b_logit_pred)\n",
+" true_labels.append(b_labels)\n",
+" pred_labels.append(pred_label)\n",
+"\n",
+"# Flatten outputs\n",
+"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
+"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
+"true_labels = [item for sublist in true_labels for item in sublist]\n",
+"\n",
+"# Converting flattened binary values to boolean values\n",
+"true_bools = [tl == 1 for tl in true_labels]\n",
+"pred_bools = [pl > 0.50 for pl in pred_labels] "
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "1SJCSrQTRCexFCNCIyRrzL",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"We create a classification report:"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
+"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
+"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
+"print(clf_report)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "eBprrgF086mznPbPVBpOLS",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"# Creating a map of class names from class numbers\n",
+"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "yELHY0IEwMlMw3x6e7hoD1",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"true_label_idxs, pred_label_idxs = [], []\n",
+"\n",
+"for vals in true_bools:\n",
+" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
+"for vals in pred_bools:\n",
+" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "jH0S35dDteUch01sa6me6e",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"true_label_texts, pred_label_texts = [], []\n",
+"\n",
+"for vals in true_label_idxs:\n",
+" if vals:\n",
+" true_label_texts.append([idx2label[val] for val in vals])\n",
+" else:\n",
+" true_label_texts.append(vals)\n",
+"\n",
+"for vals in pred_label_idxs:\n",
+" if vals:\n",
+" pred_label_texts.append([idx2label[val] for val in vals])\n",
+" else:\n",
+" pred_label_texts.append(vals)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"symptom_texts = [tokenizer.decode(text,\n",
+" skip_special_tokens=True,\n",
+" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "SxUmVHfQISEeptg1SawOmB",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
+" 'true_labels': true_label_texts, \n",
+" 'pred_labels':pred_label_texts})\n",
+"comparisons_df.to_csv('comparisons.csv')\n",
+"comparisons_df"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "BxFNigNGRLTOqraI55BPSH",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"### Shapley analysis"
+],
+"metadata": {
+"collapsed": false
+}
+},
+{
+"cell_type": "code",
+"source": [
+"explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "OpdZcoenX2HwzLdai7K5UA",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "FvbCMfIDlcf16YSvb8wNQv",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"shap.plots.text(shap_values)"
+],
+"execution_count": null,
+"outputs": [],
+"metadata": {
+"datalore": {
+"node_id": "TSxvakWLPCpjVMWi9ZdEbd",
+"type": "CODE",
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true
+}
+}
+}
+],
+"metadata": {
+"kernelspec": {
+"name": "python3",
+"language": "python",
+"display_name": "Python 3 (ipykernel)"
+},
+"datalore": {
+"computation_mode": "JUPYTER",
+"package_manager": "pip",
+"base_environment": "default",
+"packages": [
+{
+"name": "datasets",
+"version": "2.16.1",
+"source": "PIP"
+},
+{
+"name": "torch",
+"version": "2.1.2",
+"source": "PIP"
+},
+{
+"name": "accelerate",
+"version": "0.26.1",
+"source": "PIP"
+}
+],
+"report_row_ids": [
+"un8W7ez7ZwoGb5Co6nydEV",
+"40nN9Hvgi1clHNV5RAemI5",
+"TgRD90H5NSPpKS41OeXI1w",
+"ZOm5BfUs3h1EGLaUkBGeEB",
+"kOP0CZWNSk6vqE3wkPp7Vc",
+"W4PWcOu2O2pRaZyoE2W80h",
+"RolbOnQLIftk0vy9mIcz5M",
+"8OPhUgbaNJmOdiq5D3a6vK",
+"5Qrt3jSvSrpK6Ne1hS6shL",
+"hTq7nFUrovN5Ao4u6dIYWZ",
+"I8WNZLpJ1DVP2wiCW7YBIB",
+"SawhU3I9BewSE1XBPstpNJ",
+"80EtLEl2FIE4FqbWnUD3nT"
+],
+"version": 3
+}
+},
+"nbformat": 4,
+"nbformat_minor": 4
+}

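The checkpoint above frames the task as multi-label classification: MultiLabelTrainer replaces the stock Trainer loss with BCEWithLogitsLoss, so each of the seven outcome labels is scored as an independent binary decision, and accuracy_threshold counts per-label agreement after a sigmoid at 0.5. A self-contained sketch of that loss and metric on a toy batch (note that the committed cell passes an undefined num_labels to from_pretrained; len(CLASS_NAMES) is presumably the intended value):

    import torch

    CLASS_NAMES = ["DIED", "ER_VISIT", "HOSPITAL", "OFC_VISIT", "X_STAY", "DISABLE", "D_PRESENTED"]
    num_labels = len(CLASS_NAMES)  # the value the notebook's from_pretrained call appears to intend

    # Toy batch: two reports, seven independent binary labels each.
    logits = torch.tensor([[ 2.0, -1.0,  0.5, -2.0,  0.1, -0.3,  1.5],
                           [-0.7,  0.9, -1.2,  0.4, -0.1,  2.2, -2.5]])
    labels = torch.tensor([[1., 0., 1., 0., 0., 0., 1.],
                           [0., 1., 0., 1., 0., 1., 0.]])

    # BCEWithLogitsLoss applies a sigmoid per logit, scoring every label slot
    # independently; this is what MultiLabelTrainer.compute_loss substitutes
    # for the single-label cross-entropy the default Trainer would use.
    loss = torch.nn.BCEWithLogitsLoss()(logits, labels)

    # accuracy_threshold: the fraction of label slots where sigmoid(logit) > 0.5
    # matches the ground truth, mirroring the notebook's metric.
    acc = ((logits.sigmoid() > 0.5) == labels.bool()).float().mean().item()
    print(f"loss={loss.item():.4f} accuracy_thresh={acc:.4f}")
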
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-1-52-4Z.ipynb
ADDED
@@ -0,0 +1,788 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {
|
11 |
+
"collapsed": false
|
12 |
+
}
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"source": [
|
17 |
+
"import pandas as pd\n",
|
18 |
+
"import numpy as np\n",
|
19 |
+
"import torch\n",
|
20 |
+
"import os\n",
|
21 |
+
"from typing import List\n",
|
22 |
+
"from datasets import load_dataset\n",
|
23 |
+
"import shap\n",
|
24 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
25 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
|
26 |
+
"\n",
|
27 |
+
"%load_ext watermark"
|
28 |
+
],
|
29 |
+
"outputs": [
|
30 |
+
{
|
31 |
+
"output_type": "error",
|
32 |
+
"ename": "ModuleNotFoundError",
|
33 |
+
"evalue": "No module named 'torch'",
|
34 |
+
"traceback": [
|
35 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
36 |
+
"\u001b[0;31mModuleNotFoundError\u001b[0m Traceback (most recent call last)",
|
37 |
+
"Cell \u001b[0;32mIn[2], line 3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mpandas\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mpd\u001b[39;00m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mnumpy\u001b[39;00m \u001b[38;5;28;01mas\u001b[39;00m \u001b[38;5;21;01mnp\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mtorch\u001b[39;00m\n\u001b[1;32m 4\u001b[0m \u001b[38;5;28;01mimport\u001b[39;00m \u001b[38;5;21;01mos\u001b[39;00m\n\u001b[1;32m 5\u001b[0m \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;21;01mtyping\u001b[39;00m \u001b[38;5;28;01mimport\u001b[39;00m List\n",
|
38 |
+
"\u001b[0;31mModuleNotFoundError\u001b[0m: No module named 'torch'"
|
39 |
+
]
|
40 |
+
}
|
41 |
+
],
|
42 |
+
"execution_count": 2,
|
43 |
+
"metadata": {
|
44 |
+
"datalore": {
|
45 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
46 |
+
"type": "CODE",
|
47 |
+
"hide_input_from_viewers": false,
|
48 |
+
"hide_output_from_viewers": false,
|
49 |
+
"report_properties": {
|
50 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
51 |
+
}
|
52 |
+
},
|
53 |
+
"gather": {
|
54 |
+
"logged": 1706406690290
|
55 |
+
}
|
56 |
+
}
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
62 |
+
"\n",
|
63 |
+
"SEED: int = 42\n",
|
64 |
+
"\n",
|
65 |
+
"BATCH_SIZE: int = 8\n",
|
66 |
+
"EPOCHS: int = 1\n",
|
67 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
68 |
+
"\n",
|
69 |
+
"CLASS_NAMES: List[str] = [\"DIED\",\n",
|
70 |
+
" \"ER_VISIT\",\n",
|
71 |
+
" \"HOSPITAL\",\n",
|
72 |
+
" \"OFC_VISIT\",\n",
|
73 |
+
" \"X_STAY\",\n",
|
74 |
+
" \"DISABLE\",\n",
|
75 |
+
" \"D_PRESENTED\"]\n",
|
76 |
+
"\n",
|
77 |
+
"# WandB configuration\n",
|
78 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
|
79 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
|
80 |
+
],
|
81 |
+
"outputs": [],
|
82 |
+
"execution_count": null,
|
83 |
+
"metadata": {
|
84 |
+
"collapsed": false
|
85 |
+
}
|
86 |
+
},
|
87 |
+
{
|
88 |
+
"cell_type": "code",
|
89 |
+
"source": [
|
90 |
+
"%watermark --iversion"
|
91 |
+
],
|
92 |
+
"outputs": [],
|
93 |
+
"execution_count": null,
|
94 |
+
"metadata": {
|
95 |
+
"collapsed": false
|
96 |
+
}
|
97 |
+
},
|
98 |
+
{
|
99 |
+
"cell_type": "code",
|
100 |
+
"source": [
|
101 |
+
"!nvidia-smi"
|
102 |
+
],
|
103 |
+
"outputs": [
|
104 |
+
{
|
105 |
+
"output_type": "stream",
|
106 |
+
"name": "stdout",
|
107 |
+
"text": "Sun Jan 28 01:31:42 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 0MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 29C P0 36W / 250W | 0MiB / 16384MiB | 1% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
108 |
+
}
|
109 |
+
],
|
110 |
+
"execution_count": 4,
|
111 |
+
"metadata": {
|
112 |
+
"datalore": {
|
113 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
114 |
+
"type": "CODE",
|
115 |
+
"hide_input_from_viewers": true,
|
116 |
+
"hide_output_from_viewers": true
|
117 |
+
}
|
118 |
+
}
|
119 |
+
},
|
120 |
+
{
|
121 |
+
"attachments": {},
|
122 |
+
"cell_type": "markdown",
|
123 |
+
"source": [
|
124 |
+
"## Loading the data set"
|
125 |
+
],
|
126 |
+
"metadata": {
|
127 |
+
"datalore": {
|
128 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
129 |
+
"type": "MD",
|
130 |
+
"hide_input_from_viewers": false,
|
131 |
+
"hide_output_from_viewers": false,
|
132 |
+
"report_properties": {
|
133 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
134 |
+
}
|
135 |
+
}
|
136 |
+
}
|
137 |
+
},
|
138 |
+
{
|
139 |
+
"cell_type": "code",
|
140 |
+
"source": [
|
141 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
142 |
+
],
|
143 |
+
"outputs": [],
|
144 |
+
"execution_count": null,
|
145 |
+
"metadata": {
|
146 |
+
"collapsed": false
|
147 |
+
}
|
148 |
+
},
|
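A quick sanity check of what `load_dataset` returned can be helpful at this point (a sketch; the exact split names and columns depend on the hosted dataset):

print(dataset)              # DatasetDict listing splits and row counts
print(dataset["train"][0])  # one raw example with its text and labels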
149 |
+
{
|
150 |
+
"cell_type": "markdown",
|
151 |
+
"source": [
|
152 |
+
"### Tokenisation and encoding"
|
153 |
+
],
|
154 |
+
"metadata": {
|
155 |
+
"collapsed": false
|
156 |
+
}
|
157 |
+
},
|
158 |
+
{
|
159 |
+
"cell_type": "code",
|
160 |
+
"source": [
|
161 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
|
162 |
+
],
|
163 |
+
"outputs": [],
|
164 |
+
"execution_count": null,
|
165 |
+
"metadata": {
|
166 |
+
"datalore": {
|
167 |
+
"node_id": "I7n646PIscsUZRoHu6m7zm",
|
168 |
+
"type": "CODE",
|
169 |
+
"hide_input_from_viewers": true,
|
170 |
+
"hide_output_from_viewers": true
|
171 |
+
}
|
172 |
+
}
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "code",
|
176 |
+
"source": [
|
177 |
+
"def tokenize_and_encode(examples):\n",
|
178 |
+
" return tokenizer(examples[\"text\"], truncation=True)"
|
179 |
+
],
|
180 |
+
"outputs": [],
|
181 |
+
"execution_count": null,
|
182 |
+
"metadata": {
|
183 |
+
"datalore": {
|
184 |
+
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
|
185 |
+
"type": "CODE",
|
186 |
+
"hide_input_from_viewers": true,
|
187 |
+
"hide_output_from_viewers": true
|
188 |
+
}
|
189 |
+
}
|
190 |
+
},
|
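As a sketch of what `tokenize_and_encode` produces (hypothetical report text, run after the cell above), the tokenizer returns `input_ids` and an `attention_mask` per example, truncated to the 512-token limit of `distilbert-base-uncased`:

enc = tokenize_and_encode({"text": ["Patient reported injection-site soreness."]})
print(enc["input_ids"][0])       # token ids, starting with [CLS] (101) and ending with [SEP] (102)
print(enc["attention_mask"][0])  # 1 for every retained token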
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"source": [
|
194 |
+
"cols = dataset[\"train\"].column_names\n",
|
195 |
+
"cols.remove(\"labels\")\n",
|
196 |
+
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
|
197 |
+
],
|
198 |
+
"outputs": [],
|
199 |
+
"execution_count": null,
|
200 |
+
"metadata": {
|
201 |
+
"datalore": {
|
202 |
+
"node_id": "slHeNysZOX9uWS9PB7jFDb",
|
203 |
+
"type": "CODE",
|
204 |
+
"hide_input_from_viewers": true,
|
205 |
+
"hide_output_from_viewers": true
|
206 |
+
}
|
207 |
+
}
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "markdown",
|
211 |
+
"source": [
|
212 |
+
"### Training"
|
213 |
+
],
|
214 |
+
"metadata": {
|
215 |
+
"collapsed": false
|
216 |
+
}
|
217 |
+
},
|
218 |
+
{
|
219 |
+
"cell_type": "code",
|
220 |
+
"source": [
|
221 |
+
"class MultiLabelTrainer(Trainer):\n",
|
222 |
+
" def compute_loss(self, model, inputs, return_outputs=False):\n",
|
223 |
+
" labels = inputs.pop(\"labels\")\n",
|
224 |
+
" outputs = model(**inputs)\n",
|
225 |
+
" logits = outputs.logits\n",
|
226 |
+
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
|
227 |
+
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
|
228 |
+
" labels.float().view(-1, self.model.config.num_labels))\n",
|
229 |
+
" return (loss, outputs) if return_outputs else loss"
|
230 |
+
],
|
231 |
+
"outputs": [],
|
232 |
+
"execution_count": null,
|
233 |
+
"metadata": {
|
234 |
+
"datalore": {
|
235 |
+
"node_id": "itXWkbDw9sqbkMuDP84QoT",
|
236 |
+
"type": "CODE",
|
237 |
+
"hide_input_from_viewers": true,
|
238 |
+
"hide_output_from_viewers": true
|
239 |
+
}
|
240 |
+
}
|
241 |
+
},
|
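The custom loss above is what makes this a multi-label rather than a multi-class setup: `BCEWithLogitsLoss` scores each of the seven outcome logits as an independent binary decision. A self-contained numerical sketch:

import torch

loss_fct = torch.nn.BCEWithLogitsLoss()
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0, 0.0, -2.0, 1.0]])  # one report, seven label slots
labels = torch.tensor([[1.0,  0.0, 1.0,  0.0, 0.0,  0.0, 0.0]])  # multi-hot targets
print(loss_fct(logits, labels))  # mean binary cross-entropy over all seven slots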
242 |
+
{
|
243 |
+
"cell_type": "code",
|
244 |
+
"source": [
|
245 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=num_labels).to(\"cuda\")"
|
246 |
+
],
|
247 |
+
"outputs": [],
|
248 |
+
"execution_count": null,
|
249 |
+
"metadata": {
|
250 |
+
"datalore": {
|
251 |
+
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
|
252 |
+
"type": "CODE",
|
253 |
+
"hide_input_from_viewers": true,
|
254 |
+
"hide_output_from_viewers": true
|
255 |
+
}
|
256 |
+
}
|
257 |
+
},
|
258 |
+
{
|
259 |
+
"cell_type": "code",
|
260 |
+
"source": [
|
261 |
+
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
|
262 |
+
" y_pred = torch.from_numpy(y_pred)\n",
|
263 |
+
" y_true = torch.from_numpy(y_true)\n",
|
264 |
+
"\n",
|
265 |
+
" if sigmoid:\n",
|
266 |
+
" y_pred = y_pred.sigmoid()\n",
|
267 |
+
"\n",
|
268 |
+
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
|
269 |
+
],
|
270 |
+
"outputs": [],
|
271 |
+
"execution_count": null,
|
272 |
+
"metadata": {
|
273 |
+
"datalore": {
|
274 |
+
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
|
275 |
+
"type": "CODE",
|
276 |
+
"hide_input_from_viewers": true,
|
277 |
+
"hide_output_from_viewers": true
|
278 |
+
}
|
279 |
+
}
|
280 |
+
},
|
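Note that `accuracy_threshold` measures element-wise correctness across every label slot after a sigmoid and a 0.5 cut-off, not exact-match accuracy per report. A small worked example (raw logits in, as the `Trainer` passes them):

import numpy as np

y_pred = np.array([[ 2.0, -1.0], [-0.5, 3.0]])  # logits: two examples, two labels
y_true = np.array([[ 1.0,  0.0], [ 1.0, 1.0]])
print(accuracy_threshold(y_pred, y_true))       # 3 of 4 slots correct -> 0.75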
281 |
+
{
|
282 |
+
"cell_type": "code",
|
283 |
+
"source": [
|
284 |
+
"def compute_metrics(eval_pred):\n",
|
285 |
+
" predictions, labels = eval_pred\n",
|
286 |
+
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
|
287 |
+
],
|
288 |
+
"outputs": [],
|
289 |
+
"execution_count": null,
|
290 |
+
"metadata": {
|
291 |
+
"datalore": {
|
292 |
+
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
|
293 |
+
"type": "CODE",
|
294 |
+
"hide_input_from_viewers": true,
|
295 |
+
"hide_output_from_viewers": true
|
296 |
+
}
|
297 |
+
}
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"cell_type": "code",
|
301 |
+
"source": [
|
302 |
+
"args = TrainingArguments(\n",
|
303 |
+
" output_dir=\"vaers\",\n",
|
304 |
+
" evaluation_strategy=\"epoch\",\n",
|
305 |
+
" learning_rate=2e-5,\n",
|
306 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
307 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
308 |
+
" num_train_epochs=EPOCHS,\n",
|
309 |
+
" weight_decay=.01,\n",
|
310 |
+
" report_to=[\"wandb\"]\n",
|
311 |
+
")"
|
312 |
+
],
|
313 |
+
"outputs": [],
|
314 |
+
"execution_count": null,
|
315 |
+
"metadata": {
|
316 |
+
"datalore": {
|
317 |
+
"node_id": "1iPZOTKPwSkTgX5dORqT89",
|
318 |
+
"type": "CODE",
|
319 |
+
"hide_input_from_viewers": true,
|
320 |
+
"hide_output_from_viewers": true
|
321 |
+
}
|
322 |
+
}
|
323 |
+
},
|
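`report_to=["wandb"]` makes the `Trainer` pick up the `WANDB_PROJECT` and `WANDB_LOG_MODEL` environment variables set at the top of the notebook. A sketch of the equivalent explicit initialisation (assumes a logged-in W&B session):

import wandb

wandb.init(project="DAEDRA model training")  # otherwise inferred from WANDB_PROJECT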
324 |
+
{
|
325 |
+
"cell_type": "code",
|
326 |
+
"source": [
|
327 |
+
"multi_label_trainer = MultiLabelTrainer(\n",
|
328 |
+
" model, \n",
|
329 |
+
" args, \n",
|
330 |
+
" train_dataset=ds_enc[\"train\"], \n",
|
331 |
+
" eval_dataset=ds_enc[\"test\"], \n",
|
332 |
+
" compute_metrics=compute_metrics, \n",
|
333 |
+
" tokenizer=tokenizer\n",
|
334 |
+
")"
|
335 |
+
],
|
336 |
+
"outputs": [],
|
337 |
+
"execution_count": null,
|
338 |
+
"metadata": {
|
339 |
+
"datalore": {
|
340 |
+
"node_id": "bnRkNvRYltLun6gCEgL7v0",
|
341 |
+
"type": "CODE",
|
342 |
+
"hide_input_from_viewers": true,
|
343 |
+
"hide_output_from_viewers": true
|
344 |
+
}
|
345 |
+
}
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"cell_type": "code",
|
349 |
+
"source": [
|
350 |
+
"multi_label_trainer.evaluate()"
|
351 |
+
],
|
352 |
+
"outputs": [],
|
353 |
+
"execution_count": null,
|
354 |
+
"metadata": {
|
355 |
+
"datalore": {
|
356 |
+
"node_id": "LO54PlDkWQdFrzV25FvduB",
|
357 |
+
"type": "CODE",
|
358 |
+
"hide_input_from_viewers": true,
|
359 |
+
"hide_output_from_viewers": true
|
360 |
+
}
|
361 |
+
}
|
362 |
+
},
|
363 |
+
{
|
364 |
+
"cell_type": "code",
|
365 |
+
"source": [
|
366 |
+
"multi_label_trainer.train()"
|
367 |
+
],
|
368 |
+
"outputs": [],
|
369 |
+
"execution_count": null,
|
370 |
+
"metadata": {
|
371 |
+
"datalore": {
|
372 |
+
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
|
373 |
+
"type": "CODE",
|
374 |
+
"hide_input_from_viewers": true,
|
375 |
+
"hide_output_from_viewers": true
|
376 |
+
}
|
377 |
+
}
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "markdown",
|
381 |
+
"source": [
|
382 |
+
"### Evaluation"
|
383 |
+
],
|
384 |
+
"metadata": {
|
385 |
+
"collapsed": false
|
386 |
+
}
|
387 |
+
},
|
388 |
+
{
|
389 |
+
"cell_type": "markdown",
|
390 |
+
"source": [
|
391 |
+
"We instantiate a classifier `pipeline` and push it to CUDA."
|
392 |
+
],
|
393 |
+
"metadata": {
|
394 |
+
"collapsed": false
|
395 |
+
}
|
396 |
+
},
|
397 |
+
{
|
398 |
+
"cell_type": "code",
|
399 |
+
"source": [
|
400 |
+
"classifier = pipeline(\"text-classification\", \n",
|
401 |
+
" model, \n",
|
402 |
+
" tokenizer=tokenizer, \n",
|
403 |
+
" device=\"cuda:0\")"
|
404 |
+
],
|
405 |
+
"outputs": [],
|
406 |
+
"execution_count": null,
|
407 |
+
"metadata": {
|
408 |
+
"datalore": {
|
409 |
+
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
|
410 |
+
"type": "CODE",
|
411 |
+
"hide_input_from_viewers": true,
|
412 |
+
"hide_output_from_viewers": true
|
413 |
+
}
|
414 |
+
}
|
415 |
+
},
|
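A sketch of a single prediction through the pipeline (hypothetical report text); passing `top_k=None` returns a score for every class rather than just the argmax, which is the useful shape for a multi-label model:

print(classifier("Patient was admitted to hospital with myocarditis.", top_k=None))
# [{'label': ..., 'score': ...}, ...] -- one entry per outcome class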
416 |
+
{
|
417 |
+
"cell_type": "markdown",
|
418 |
+
"source": [
|
419 |
+
"We use the same tokenizer used for training to tokenize/encode the validation set."
|
420 |
+
],
|
421 |
+
"metadata": {
|
422 |
+
"collapsed": false
|
423 |
+
}
|
424 |
+
},
|
425 |
+
{
|
426 |
+
"cell_type": "code",
|
427 |
+
"source": [
|
428 |
+
"test_encodings = tokenizer.batch_encode_plus(dataset[\"validate\"][\"text\"], \n",
|
429 |
+
" max_length=255, \n",
|
430 |
+
" pad_to_max_length=True, \n",
|
431 |
+
" return_token_type_ids=True, \n",
|
432 |
+
" truncation=True)"
|
433 |
+
],
|
434 |
+
"outputs": [],
|
435 |
+
"execution_count": null,
|
436 |
+
"metadata": {
|
437 |
+
"datalore": {
|
438 |
+
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
|
439 |
+
"type": "CODE",
|
440 |
+
"hide_input_from_viewers": true,
|
441 |
+
"hide_output_from_viewers": true
|
442 |
+
}
|
443 |
+
}
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"cell_type": "markdown",
|
447 |
+
"source": [
|
448 |
+
"Once we've made the data loadable by putting it into a `DataLoader`, we "
|
449 |
+
],
|
450 |
+
"metadata": {
|
451 |
+
"collapsed": false
|
452 |
+
}
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"cell_type": "code",
|
456 |
+
"source": [
|
457 |
+
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
|
458 |
+
" torch.tensor(test_encodings['attention_mask']), \n",
|
459 |
+
" torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n",
|
460 |
+
" torch.tensor(test_encodings['token_type_ids']))\n",
|
461 |
+
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
|
462 |
+
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
|
463 |
+
" batch_size=BATCH_SIZE)"
|
464 |
+
],
|
465 |
+
"outputs": [],
|
466 |
+
"execution_count": null,
|
467 |
+
"metadata": {
|
468 |
+
"datalore": {
|
469 |
+
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
|
470 |
+
"type": "CODE",
|
471 |
+
"hide_input_from_viewers": true,
|
472 |
+
"hide_output_from_viewers": true
|
473 |
+
}
|
474 |
+
}
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"cell_type": "code",
|
478 |
+
"source": [
|
479 |
+
"model.eval()\n",
|
480 |
+
"\n",
|
481 |
+
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
|
482 |
+
"\n",
|
483 |
+
"for i, batch in enumerate(test_dataloader):\n",
|
484 |
+
" batch = tuple(t.to(device) for t in batch)\n",
|
485 |
+
" # Unpack the inputs from our dataloader\n",
|
486 |
+
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
|
487 |
+
" \n",
|
488 |
+
" with torch.no_grad():\n",
|
489 |
+
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
|
490 |
+
" b_logit_pred = outs[0]\n",
|
491 |
+
" pred_label = torch.sigmoid(b_logit_pred)\n",
|
492 |
+
"\n",
|
493 |
+
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
|
494 |
+
" pred_label = pred_label.to('cpu').numpy()\n",
|
495 |
+
" b_labels = b_labels.to('cpu').numpy()\n",
|
496 |
+
"\n",
|
497 |
+
" tokenized_texts.append(b_input_ids)\n",
|
498 |
+
" logit_preds.append(b_logit_pred)\n",
|
499 |
+
" true_labels.append(b_labels)\n",
|
500 |
+
" pred_labels.append(pred_label)\n",
|
501 |
+
"\n",
|
502 |
+
"# Flatten outputs\n",
|
503 |
+
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
|
504 |
+
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
|
505 |
+
"true_labels = [item for sublist in true_labels for item in sublist]\n",
|
506 |
+
"\n",
|
507 |
+
"# Converting flattened binary values to boolean values\n",
|
508 |
+
"true_bools = [tl == 1 for tl in true_labels]\n",
|
509 |
+
"pred_bools = [pl > 0.50 for pl in pred_labels] "
|
510 |
+
],
|
511 |
+
"outputs": [],
|
512 |
+
"execution_count": null,
|
513 |
+
"metadata": {
|
514 |
+
"datalore": {
|
515 |
+
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
|
516 |
+
"type": "CODE",
|
517 |
+
"hide_input_from_viewers": true,
|
518 |
+
"hide_output_from_viewers": true
|
519 |
+
}
|
520 |
+
}
|
521 |
+
},
|
522 |
+
{
|
523 |
+
"cell_type": "markdown",
|
524 |
+
"source": [
|
525 |
+
"We create a classification report:"
|
526 |
+
],
|
527 |
+
"metadata": {
|
528 |
+
"collapsed": false
|
529 |
+
}
|
530 |
+
},
|
531 |
+
{
|
532 |
+
"cell_type": "code",
|
533 |
+
"source": [
|
534 |
+
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
|
535 |
+
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
|
536 |
+
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
|
537 |
+
"print(clf_report)"
|
538 |
+
],
|
539 |
+
"outputs": [],
|
540 |
+
"execution_count": null,
|
541 |
+
"metadata": {
|
542 |
+
"datalore": {
|
543 |
+
"node_id": "eBprrgF086mznPbPVBpOLS",
|
544 |
+
"type": "CODE",
|
545 |
+
"hide_input_from_viewers": true,
|
546 |
+
"hide_output_from_viewers": true
|
547 |
+
}
|
548 |
+
}
|
549 |
+
},
|
550 |
+
{
|
551 |
+
"cell_type": "markdown",
|
552 |
+
"source": [
|
553 |
+
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
|
554 |
+
],
|
555 |
+
"metadata": {
|
556 |
+
"collapsed": false
|
557 |
+
}
|
558 |
+
},
|
559 |
+
{
|
560 |
+
"cell_type": "code",
|
561 |
+
"source": [
|
562 |
+
"# Creating a map of class names from class numbers\n",
|
563 |
+
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
|
564 |
+
],
|
565 |
+
"outputs": [],
|
566 |
+
"execution_count": null,
|
567 |
+
"metadata": {
|
568 |
+
"datalore": {
|
569 |
+
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
|
570 |
+
"type": "CODE",
|
571 |
+
"hide_input_from_viewers": true,
|
572 |
+
"hide_output_from_viewers": true
|
573 |
+
}
|
574 |
+
}
|
575 |
+
},
|
576 |
+
{
|
577 |
+
"cell_type": "code",
|
578 |
+
"source": [
|
579 |
+
"true_label_idxs, pred_label_idxs = [], []\n",
|
580 |
+
"\n",
|
581 |
+
"for vals in true_bools:\n",
|
582 |
+
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
|
583 |
+
"for vals in pred_bools:\n",
|
584 |
+
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
|
585 |
+
],
|
586 |
+
"outputs": [],
|
587 |
+
"execution_count": null,
|
588 |
+
"metadata": {
|
589 |
+
"datalore": {
|
590 |
+
"node_id": "jH0S35dDteUch01sa6me6e",
|
591 |
+
"type": "CODE",
|
592 |
+
"hide_input_from_viewers": true,
|
593 |
+
"hide_output_from_viewers": true
|
594 |
+
}
|
595 |
+
}
|
596 |
+
},
|
597 |
+
{
|
598 |
+
"cell_type": "code",
|
599 |
+
"source": [
|
600 |
+
"true_label_texts, pred_label_texts = [], []\n",
|
601 |
+
"\n",
|
602 |
+
"for vals in true_label_idxs:\n",
|
603 |
+
" if vals:\n",
|
604 |
+
" true_label_texts.append([idx2label[val] for val in vals])\n",
|
605 |
+
" else:\n",
|
606 |
+
" true_label_texts.append(vals)\n",
|
607 |
+
"\n",
|
608 |
+
"for vals in pred_label_idxs:\n",
|
609 |
+
" if vals:\n",
|
610 |
+
" pred_label_texts.append([idx2label[val] for val in vals])\n",
|
611 |
+
" else:\n",
|
612 |
+
" pred_label_texts.append(vals)"
|
613 |
+
],
|
614 |
+
"outputs": [],
|
615 |
+
"execution_count": null,
|
616 |
+
"metadata": {
|
617 |
+
"datalore": {
|
618 |
+
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
|
619 |
+
"type": "CODE",
|
620 |
+
"hide_input_from_viewers": true,
|
621 |
+
"hide_output_from_viewers": true
|
622 |
+
}
|
623 |
+
}
|
624 |
+
},
|
625 |
+
{
|
626 |
+
"cell_type": "code",
|
627 |
+
"source": [
|
628 |
+
"symptom_texts = [tokenizer.decode(text,\n",
|
629 |
+
" skip_special_tokens=True,\n",
|
630 |
+
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
|
631 |
+
],
|
632 |
+
"outputs": [],
|
633 |
+
"execution_count": null,
|
634 |
+
"metadata": {
|
635 |
+
"datalore": {
|
636 |
+
"node_id": "SxUmVHfQISEeptg1SawOmB",
|
637 |
+
"type": "CODE",
|
638 |
+
"hide_input_from_viewers": true,
|
639 |
+
"hide_output_from_viewers": true
|
640 |
+
}
|
641 |
+
}
|
642 |
+
},
|
643 |
+
{
|
644 |
+
"cell_type": "code",
|
645 |
+
"source": [
|
646 |
+
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
|
647 |
+
" 'true_labels': true_label_texts, \n",
|
648 |
+
" 'pred_labels':pred_label_texts})\n",
|
649 |
+
"comparisons_df.to_csv('comparisons.csv')\n",
|
650 |
+
"comparisons_df"
|
651 |
+
],
|
652 |
+
"outputs": [],
|
653 |
+
"execution_count": null,
|
654 |
+
"metadata": {
|
655 |
+
"datalore": {
|
656 |
+
"node_id": "BxFNigNGRLTOqraI55BPSH",
|
657 |
+
"type": "CODE",
|
658 |
+
"hide_input_from_viewers": true,
|
659 |
+
"hide_output_from_viewers": true
|
660 |
+
}
|
661 |
+
}
|
662 |
+
},
|
663 |
+
{
|
664 |
+
"cell_type": "markdown",
|
665 |
+
"source": [
|
666 |
+
"### Shapley analysis"
|
667 |
+
],
|
668 |
+
"metadata": {
|
669 |
+
"collapsed": false
|
670 |
+
}
|
671 |
+
},
|
672 |
+
{
|
673 |
+
"cell_type": "code",
|
674 |
+
"source": [
|
675 |
+
"explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
|
676 |
+
],
|
677 |
+
"outputs": [],
|
678 |
+
"execution_count": null,
|
679 |
+
"metadata": {
|
680 |
+
"datalore": {
|
681 |
+
"node_id": "OpdZcoenX2HwzLdai7K5UA",
|
682 |
+
"type": "CODE",
|
683 |
+
"hide_input_from_viewers": true,
|
684 |
+
"hide_output_from_viewers": true
|
685 |
+
}
|
686 |
+
}
|
687 |
+
},
|
688 |
+
{
|
689 |
+
"cell_type": "code",
|
690 |
+
"source": [
|
691 |
+
"shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])"
|
692 |
+
],
|
693 |
+
"outputs": [],
|
694 |
+
"execution_count": null,
|
695 |
+
"metadata": {
|
696 |
+
"datalore": {
|
697 |
+
"node_id": "FvbCMfIDlcf16YSvb8wNQv",
|
698 |
+
"type": "CODE",
|
699 |
+
"hide_input_from_viewers": true,
|
700 |
+
"hide_output_from_viewers": true
|
701 |
+
}
|
702 |
+
}
|
703 |
+
},
|
704 |
+
{
|
705 |
+
"cell_type": "code",
|
706 |
+
"source": [
|
707 |
+
"shap.plots.text(shap_values)"
|
708 |
+
],
|
709 |
+
"outputs": [],
|
710 |
+
"execution_count": null,
|
711 |
+
"metadata": {
|
712 |
+
"datalore": {
|
713 |
+
"node_id": "TSxvakWLPCpjVMWi9ZdEbd",
|
714 |
+
"type": "CODE",
|
715 |
+
"hide_input_from_viewers": true,
|
716 |
+
"hide_output_from_viewers": true
|
717 |
+
}
|
718 |
+
}
|
719 |
+
}
|
720 |
+
],
|
721 |
+
"metadata": {
|
722 |
+
"kernelspec": {
|
723 |
+
"name": "python3",
|
724 |
+
"language": "python",
|
725 |
+
"display_name": "Python 3 (ipykernel)"
|
726 |
+
},
|
727 |
+
"datalore": {
|
728 |
+
"computation_mode": "JUPYTER",
|
729 |
+
"package_manager": "pip",
|
730 |
+
"base_environment": "default",
|
731 |
+
"packages": [
|
732 |
+
{
|
733 |
+
"name": "datasets",
|
734 |
+
"version": "2.16.1",
|
735 |
+
"source": "PIP"
|
736 |
+
},
|
737 |
+
{
|
738 |
+
"name": "torch",
|
739 |
+
"version": "2.1.2",
|
740 |
+
"source": "PIP"
|
741 |
+
},
|
742 |
+
{
|
743 |
+
"name": "accelerate",
|
744 |
+
"version": "0.26.1",
|
745 |
+
"source": "PIP"
|
746 |
+
}
|
747 |
+
],
|
748 |
+
"report_row_ids": [
|
749 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
750 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
751 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
752 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
753 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
754 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
755 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
756 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
757 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
758 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
759 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
760 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
761 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
762 |
+
],
|
763 |
+
"version": 3
|
764 |
+
},
|
765 |
+
"microsoft": {
|
766 |
+
"ms_spell_check": {
|
767 |
+
"ms_spell_check_language": "en"
|
768 |
+
}
|
769 |
+
},
|
770 |
+
"language_info": {
|
771 |
+
"name": "python",
|
772 |
+
"version": "3.8.5",
|
773 |
+
"mimetype": "text/x-python",
|
774 |
+
"codemirror_mode": {
|
775 |
+
"name": "ipython",
|
776 |
+
"version": 3
|
777 |
+
},
|
778 |
+
"pygments_lexer": "ipython3",
|
779 |
+
"nbconvert_exporter": "python",
|
780 |
+
"file_extension": ".py"
|
781 |
+
},
|
782 |
+
"nteract": {
|
783 |
+
"version": "nteract-front-end@1.0.0"
|
784 |
+
}
|
785 |
+
},
|
786 |
+
"nbformat": 4,
|
787 |
+
"nbformat_minor": 4
|
788 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-13-2-30Z.ipynb
ADDED
@@ -0,0 +1,1147 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {
|
11 |
+
"collapsed": false
|
12 |
+
}
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"source": [
|
17 |
+
"%pip install accelerate -U"
|
18 |
+
],
|
19 |
+
"outputs": [
|
20 |
+
{
|
21 |
+
"output_type": "stream",
|
22 |
+
"name": "stdout",
|
23 |
+
"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nNote: you may need to restart the kernel to use updated packages.\n"
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"execution_count": 1,
|
27 |
+
"metadata": {
|
28 |
+
"jupyter": {
|
29 |
+
"source_hidden": false,
|
30 |
+
"outputs_hidden": false
|
31 |
+
},
|
32 |
+
"nteract": {
|
33 |
+
"transient": {
|
34 |
+
"deleting": false
|
35 |
+
}
|
36 |
+
}
|
37 |
+
}
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"source": [
|
42 |
+
"%pip install transformers datasets shap watermark wandb"
|
43 |
+
],
|
44 |
+
"outputs": [
|
45 |
+
{
|
46 |
+
"output_type": "stream",
|
47 |
+
"name": "stdout",
|
48 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already 
satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) 
(2.2.1)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nNote: you may need to restart the kernel to use updated packages.\n"
|
49 |
+
}
|
50 |
+
],
|
51 |
+
"execution_count": 2,
|
52 |
+
"metadata": {
|
53 |
+
"jupyter": {
|
54 |
+
"source_hidden": false,
|
55 |
+
"outputs_hidden": false
|
56 |
+
},
|
57 |
+
"nteract": {
|
58 |
+
"transient": {
|
59 |
+
"deleting": false
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
},
|
64 |
+
{
|
65 |
+
"cell_type": "code",
|
66 |
+
"source": [
|
67 |
+
"import pandas as pd\n",
|
68 |
+
"import numpy as np\n",
|
69 |
+
"import torch\n",
|
70 |
+
"import os\n",
|
71 |
+
"from typing import List\n",
|
72 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
73 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
|
74 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
75 |
+
"from pyarrow import Table\n",
|
76 |
+
"import shap\n",
|
77 |
+
"\n",
|
78 |
+
"%load_ext watermark"
|
79 |
+
],
|
80 |
+
"outputs": [
|
81 |
+
{
|
82 |
+
"output_type": "stream",
|
83 |
+
"name": "stderr",
|
84 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 04:14:37.393442: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 04:14:38.436146: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 04:14:38.436275: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 04:14:38.436289: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
85 |
+
}
|
86 |
+
],
|
87 |
+
"execution_count": 3,
|
88 |
+
"metadata": {
|
89 |
+
"datalore": {
|
90 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
91 |
+
"type": "CODE",
|
92 |
+
"hide_input_from_viewers": false,
|
93 |
+
"hide_output_from_viewers": false,
|
94 |
+
"report_properties": {
|
95 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
96 |
+
}
|
97 |
+
},
|
98 |
+
"gather": {
|
99 |
+
"logged": 1706415280692
|
100 |
+
}
|
101 |
+
}
|
102 |
+
},
|
103 |
+
{
|
104 |
+
"cell_type": "code",
|
105 |
+
"source": [
|
106 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
107 |
+
"\n",
|
108 |
+
"SEED: int = 42\n",
|
109 |
+
"\n",
|
110 |
+
"BATCH_SIZE: int = 8\n",
|
111 |
+
"EPOCHS: int = 1\n",
|
112 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
113 |
+
"\n",
|
114 |
+
"CLASS_NAMES: List[str] = [\"DIED\",\n",
|
115 |
+
" \"ER_VISIT\",\n",
|
116 |
+
" \"HOSPITAL\",\n",
|
117 |
+
" \"OFC_VISIT\",\n",
|
118 |
+
" #\"X_STAY\", # pruned\n",
|
119 |
+
" #\"DISABLE\", # pruned\n",
|
120 |
+
" #\"D_PRESENTED\" # pruned\n",
|
121 |
+
" ]\n",
|
122 |
+
"\n",
|
123 |
+
"\n",
|
124 |
+
"\n",
|
125 |
+
"\n",
|
126 |
+
"# WandB configuration\n",
|
127 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
|
128 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
|
129 |
+
],
|
130 |
+
"outputs": [],
|
131 |
+
"execution_count": 4,
|
132 |
+
"metadata": {
|
133 |
+
"collapsed": false,
|
134 |
+
"gather": {
|
135 |
+
"logged": 1706415281102
|
136 |
+
}
|
137 |
+
}
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"source": [
|
142 |
+
"%watermark --iversion"
|
143 |
+
],
|
144 |
+
"outputs": [
|
145 |
+
{
|
146 |
+
"output_type": "stream",
|
147 |
+
"name": "stdout",
|
148 |
+
"text": "re : 2.2.1\nlogging: 0.5.1.2\nnumpy : 1.23.5\nshap : 0.44.1\npandas : 2.0.2\ntorch : 1.12.0\n\n"
|
149 |
+
}
|
150 |
+
],
|
151 |
+
"execution_count": 5,
|
152 |
+
"metadata": {
|
153 |
+
"collapsed": false
|
154 |
+
}
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"source": [
|
159 |
+
"!nvidia-smi"
|
160 |
+
],
|
161 |
+
"outputs": [
|
162 |
+
{
|
163 |
+
"output_type": "stream",
|
164 |
+
"name": "stdout",
|
165 |
+
"text": "Sun Jan 28 04:14:40 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 29C P0 37W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 28C P0 36W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
166 |
+
}
|
167 |
+
],
|
168 |
+
"execution_count": 6,
|
169 |
+
"metadata": {
|
170 |
+
"datalore": {
|
171 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
172 |
+
"type": "CODE",
|
173 |
+
"hide_input_from_viewers": true,
|
174 |
+
"hide_output_from_viewers": true
|
175 |
+
}
|
176 |
+
}
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"attachments": {},
|
180 |
+
"cell_type": "markdown",
|
181 |
+
"source": [
|
182 |
+
"## Loading the data set"
|
183 |
+
],
|
184 |
+
"metadata": {
|
185 |
+
"datalore": {
|
186 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
187 |
+
"type": "MD",
|
188 |
+
"hide_input_from_viewers": false,
|
189 |
+
"hide_output_from_viewers": false,
|
190 |
+
"report_properties": {
|
191 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
192 |
+
}
|
193 |
+
}
|
194 |
+
}
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"source": [
|
199 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
200 |
+
],
|
201 |
+
"outputs": [],
|
202 |
+
"execution_count": 7,
|
203 |
+
"metadata": {
|
204 |
+
"collapsed": false,
|
205 |
+
"gather": {
|
206 |
+
"logged": 1706415283301
|
207 |
+
}
|
208 |
+
}
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"source": [
|
213 |
+
"We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
|
214 |
+
],
|
215 |
+
"metadata": {
|
216 |
+
"nteract": {
|
217 |
+
"transient": {
|
218 |
+
"deleting": false
|
219 |
+
}
|
220 |
+
}
|
221 |
+
}
|
222 |
+
},
|
223 |
+
{
|
224 |
+
"cell_type": "code",
|
225 |
+
"source": [
|
226 |
+
"ds = DatasetDict()\n",
|
227 |
+
"\n",
|
228 |
+
"for i in [\"test\", \"train\", \"val\"]:\n",
|
229 |
+
" tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
|
230 |
+
" ds[i] = Dataset(tab)\n",
|
231 |
+
"\n",
|
232 |
+
"dataset = ds"
|
233 |
+
],
|
234 |
+
"outputs": [],
|
235 |
+
"execution_count": 8,
|
236 |
+
"metadata": {
|
237 |
+
"jupyter": {
|
238 |
+
"source_hidden": false,
|
239 |
+
"outputs_hidden": false
|
240 |
+
},
|
241 |
+
"nteract": {
|
242 |
+
"transient": {
|
243 |
+
"deleting": false
|
244 |
+
}
|
245 |
+
},
|
246 |
+
"gather": {
|
247 |
+
"logged": 1706415283944
|
248 |
+
}
|
249 |
+
}
|
250 |
+
},
|
251 |
+
{
|
252 |
+
"cell_type": "markdown",
|
253 |
+
"source": [
|
254 |
+
"### Tokenisation and encoding"
|
255 |
+
],
|
256 |
+
"metadata": {
|
257 |
+
"collapsed": false
|
258 |
+
}
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "code",
|
262 |
+
"source": [
|
263 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
|
264 |
+
],
|
265 |
+
"outputs": [],
|
266 |
+
"execution_count": 9,
|
267 |
+
"metadata": {
|
268 |
+
"datalore": {
|
269 |
+
"node_id": "I7n646PIscsUZRoHu6m7zm",
|
270 |
+
"type": "CODE",
|
271 |
+
"hide_input_from_viewers": true,
|
272 |
+
"hide_output_from_viewers": true
|
273 |
+
},
|
274 |
+
"gather": {
|
275 |
+
"logged": 1706415284206
|
276 |
+
}
|
277 |
+
}
|
278 |
+
},
|
279 |
+
{
|
280 |
+
"cell_type": "code",
|
281 |
+
"source": [
|
282 |
+
"def tokenize_and_encode(examples):\n",
|
283 |
+
" return tokenizer(examples[\"text\"], truncation=True)"
|
284 |
+
],
|
285 |
+
"outputs": [],
|
286 |
+
"execution_count": 10,
|
287 |
+
"metadata": {
|
288 |
+
"datalore": {
|
289 |
+
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
|
290 |
+
"type": "CODE",
|
291 |
+
"hide_input_from_viewers": true,
|
292 |
+
"hide_output_from_viewers": true
|
293 |
+
},
|
294 |
+
"gather": {
|
295 |
+
"logged": 1706415284614
|
296 |
+
}
|
297 |
+
}
|
298 |
+
},
|
299 |
+
{
|
300 |
+
"cell_type": "code",
|
301 |
+
"source": [
|
302 |
+
"cols = dataset[\"train\"].column_names\n",
|
303 |
+
"cols.remove(\"labels\")\n",
|
304 |
+
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
|
305 |
+
],
|
306 |
+
"outputs": [
|
307 |
+
{
|
308 |
+
"output_type": "stream",
|
309 |
+
"name": "stderr",
|
310 |
+
"text": "Map: 100%|██████████| 15786/15786 [00:01<00:00, 10213.76 examples/s]\nMap: 100%|██████████| 73667/73667 [00:07<00:00, 10215.55 examples/s]\nMap: 100%|██████████| 15785/15785 [00:01<00:00, 10172.52 examples/s]\n"
|
311 |
+
}
|
312 |
+
],
|
313 |
+
"execution_count": 11,
|
314 |
+
"metadata": {
|
315 |
+
"datalore": {
|
316 |
+
"node_id": "slHeNysZOX9uWS9PB7jFDb",
|
317 |
+
"type": "CODE",
|
318 |
+
"hide_input_from_viewers": true,
|
319 |
+
"hide_output_from_viewers": true
|
320 |
+
},
|
321 |
+
"gather": {
|
322 |
+
"logged": 1706415294450
|
323 |
+
}
|
324 |
+
}
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "markdown",
|
328 |
+
"source": [
|
329 |
+
"### Training"
|
330 |
+
],
|
331 |
+
"metadata": {
|
332 |
+
"collapsed": false
|
333 |
+
}
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"cell_type": "code",
|
337 |
+
"source": [
|
338 |
+
"class MultiLabelTrainer(Trainer):\n",
|
339 |
+
" def compute_loss(self, model, inputs, return_outputs=False):\n",
|
340 |
+
" labels = inputs.pop(\"labels\")\n",
|
341 |
+
" outputs = model(**inputs)\n",
|
342 |
+
" logits = outputs.logits\n",
|
343 |
+
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
|
344 |
+
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
|
345 |
+
" labels.float().view(-1, self.model.config.num_labels))\n",
|
346 |
+
" return (loss, outputs) if return_outputs else loss"
|
347 |
+
],
|
348 |
+
"outputs": [],
|
349 |
+
"execution_count": 12,
|
350 |
+
"metadata": {
|
351 |
+
"datalore": {
|
352 |
+
"node_id": "itXWkbDw9sqbkMuDP84QoT",
|
353 |
+
"type": "CODE",
|
354 |
+
"hide_input_from_viewers": true,
|
355 |
+
"hide_output_from_viewers": true
|
356 |
+
},
|
357 |
+
"gather": {
|
358 |
+
"logged": 1706415294807
|
359 |
+
}
|
360 |
+
}
|
361 |
+
},
|
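`BCEWithLogitsLoss` applies an element-wise sigmoid before the binary cross-entropy, treating each of the four labels as an independent binary problem; that is what makes the trainer multi-label rather than single-label softmax. A self-contained sketch with invented numbers:

import torch

loss_fct = torch.nn.BCEWithLogitsLoss()
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])   # raw model outputs
labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])     # multi-hot targets

loss = loss_fct(logits, labels)

# Same quantity computed by hand, averaged over all label slots:
p = torch.sigmoid(logits)
manual = -(labels * p.log() + (1 - labels) * (1 - p).log()).mean()
print(loss.item(), manual.item())  # the two values agree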
362 |
+
{
|
363 |
+
"cell_type": "code",
|
364 |
+
"source": [
|
365 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
|
366 |
+
],
|
367 |
+
"outputs": [
|
368 |
+
{
|
369 |
+
"output_type": "stream",
|
370 |
+
"name": "stderr",
|
371 |
+
"text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
372 |
+
}
|
373 |
+
],
|
374 |
+
"execution_count": 13,
|
375 |
+
"metadata": {
|
376 |
+
"datalore": {
|
377 |
+
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
|
378 |
+
"type": "CODE",
|
379 |
+
"hide_input_from_viewers": true,
|
380 |
+
"hide_output_from_viewers": true
|
381 |
+
},
|
382 |
+
"gather": {
|
383 |
+
"logged": 1706415296683
|
384 |
+
}
|
385 |
+
}
|
386 |
+
},
|
387 |
+
{
|
388 |
+
"cell_type": "code",
|
389 |
+
"source": [
|
390 |
+
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
|
391 |
+
" y_pred = torch.from_numpy(y_pred)\n",
|
392 |
+
" y_true = torch.from_numpy(y_true)\n",
|
393 |
+
"\n",
|
394 |
+
" if sigmoid:\n",
|
395 |
+
" y_pred = y_pred.sigmoid()\n",
|
396 |
+
"\n",
|
397 |
+
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
|
398 |
+
],
|
399 |
+
"outputs": [],
|
400 |
+
"execution_count": 14,
|
401 |
+
"metadata": {
|
402 |
+
"datalore": {
|
403 |
+
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
|
404 |
+
"type": "CODE",
|
405 |
+
"hide_input_from_viewers": true,
|
406 |
+
"hide_output_from_viewers": true
|
407 |
+
},
|
408 |
+
"gather": {
|
409 |
+
"logged": 1706415296937
|
410 |
+
}
|
411 |
+
}
|
412 |
+
},
|
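Note that `accuracy_threshold` is element-wise accuracy over every (example, label) slot after thresholding, not exact-match accuracy, so a report with three of four labels right still contributes 0.75. A worked example with invented logits:

import numpy as np
import torch

y_pred = np.array([[2.0, -1.0], [-0.2, -0.3]])  # logits
y_true = np.array([[1.0, 0.0], [0.0, 1.0]])

probs = torch.from_numpy(y_pred).sigmoid()
hits = (probs > 0.5) == torch.from_numpy(y_true).bool()
print(hits.float().mean().item())  # 0.75: three of four slots correct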
413 |
+
{
|
414 |
+
"cell_type": "code",
|
415 |
+
"source": [
|
416 |
+
"def compute_metrics(eval_pred):\n",
|
417 |
+
" predictions, labels = eval_pred\n",
|
418 |
+
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
|
419 |
+
],
|
420 |
+
"outputs": [],
|
421 |
+
"execution_count": 15,
|
422 |
+
"metadata": {
|
423 |
+
"datalore": {
|
424 |
+
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
|
425 |
+
"type": "CODE",
|
426 |
+
"hide_input_from_viewers": true,
|
427 |
+
"hide_output_from_viewers": true
|
428 |
+
},
|
429 |
+
"gather": {
|
430 |
+
"logged": 1706415297280
|
431 |
+
}
|
432 |
+
}
|
433 |
+
},
|
434 |
+
{
|
435 |
+
"cell_type": "code",
|
436 |
+
"source": [
|
437 |
+
"args = TrainingArguments(\n",
|
438 |
+
" output_dir=\"vaers\",\n",
|
439 |
+
" evaluation_strategy=\"epoch\",\n",
|
440 |
+
" learning_rate=2e-5,\n",
|
441 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
442 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
443 |
+
" num_train_epochs=EPOCHS,\n",
|
444 |
+
" weight_decay=.01,\n",
|
445 |
+
" report_to=[\"wandb\"]\n",
|
446 |
+
")"
|
447 |
+
],
|
448 |
+
"outputs": [],
|
449 |
+
"execution_count": 16,
|
450 |
+
"metadata": {
|
451 |
+
"datalore": {
|
452 |
+
"node_id": "1iPZOTKPwSkTgX5dORqT89",
|
453 |
+
"type": "CODE",
|
454 |
+
"hide_input_from_viewers": true,
|
455 |
+
"hide_output_from_viewers": true
|
456 |
+
},
|
457 |
+
"gather": {
|
458 |
+
"logged": 1706415297551
|
459 |
+
}
|
460 |
+
}
|
461 |
+
},
|
462 |
+
{
|
463 |
+
"cell_type": "code",
|
464 |
+
"source": [
|
465 |
+
"multi_label_trainer = MultiLabelTrainer(\n",
|
466 |
+
" model, \n",
|
467 |
+
" args, \n",
|
468 |
+
" train_dataset=ds_enc[\"train\"], \n",
|
469 |
+
" eval_dataset=ds_enc[\"test\"], \n",
|
470 |
+
" compute_metrics=compute_metrics, \n",
|
471 |
+
" tokenizer=tokenizer\n",
|
472 |
+
")"
|
473 |
+
],
|
474 |
+
"outputs": [
|
475 |
+
{
|
476 |
+
"output_type": "stream",
|
477 |
+
"name": "stderr",
|
478 |
+
"text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
|
479 |
+
}
|
480 |
+
],
|
481 |
+
"execution_count": 17,
|
482 |
+
"metadata": {
|
483 |
+
"datalore": {
|
484 |
+
"node_id": "bnRkNvRYltLun6gCEgL7v0",
|
485 |
+
"type": "CODE",
|
486 |
+
"hide_input_from_viewers": true,
|
487 |
+
"hide_output_from_viewers": true
|
488 |
+
},
|
489 |
+
"gather": {
|
490 |
+
"logged": 1706415297795
|
491 |
+
}
|
492 |
+
}
|
493 |
+
},
|
494 |
+
{
|
495 |
+
"cell_type": "code",
|
496 |
+
"source": [
|
497 |
+
"multi_label_trainer.evaluate()"
|
498 |
+
],
|
499 |
+
"outputs": [
|
500 |
+
{
|
501 |
+
"output_type": "display_data",
|
502 |
+
"data": {
|
503 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
504 |
+
"text/html": "\n <div>\n \n <progress value='987' max='987' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [987/987 01:13]\n </div>\n "
|
505 |
+
},
|
506 |
+
"metadata": {}
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"output_type": "stream",
|
510 |
+
"name": "stderr",
|
511 |
+
"text": "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
|
512 |
+
},
|
513 |
+
{
|
514 |
+
"output_type": "display_data",
|
515 |
+
"data": {
|
516 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
517 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
518 |
+
},
|
519 |
+
"metadata": {}
|
520 |
+
},
|
521 |
+
{
|
522 |
+
"output_type": "display_data",
|
523 |
+
"data": {
|
524 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
525 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_041615-nnw129w4</code>"
|
526 |
+
},
|
527 |
+
"metadata": {}
|
528 |
+
},
|
529 |
+
{
|
530 |
+
"output_type": "display_data",
|
531 |
+
"data": {
|
532 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
533 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/nnw129w4' target=\"_blank\">grateful-shadow-2</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
534 |
+
},
|
535 |
+
"metadata": {}
|
536 |
+
},
|
537 |
+
{
|
538 |
+
"output_type": "display_data",
|
539 |
+
"data": {
|
540 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
541 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
542 |
+
},
|
543 |
+
"metadata": {}
|
544 |
+
},
|
545 |
+
{
|
546 |
+
"output_type": "display_data",
|
547 |
+
"data": {
|
548 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
549 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/nnw129w4' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/nnw129w4</a>"
|
550 |
+
},
|
551 |
+
"metadata": {}
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"output_type": "execute_result",
|
555 |
+
"execution_count": 18,
|
556 |
+
"data": {
|
557 |
+
"text/plain": "{'eval_loss': 0.712559163570404,\n 'eval_accuracy_thresh': 0.36481693387031555,\n 'eval_runtime': 76.4156,\n 'eval_samples_per_second': 206.581,\n 'eval_steps_per_second': 12.916}"
|
558 |
+
},
|
559 |
+
"metadata": {}
|
560 |
+
}
|
561 |
+
],
|
562 |
+
"execution_count": 18,
|
563 |
+
"metadata": {
|
564 |
+
"datalore": {
|
565 |
+
"node_id": "LO54PlDkWQdFrzV25FvduB",
|
566 |
+
"type": "CODE",
|
567 |
+
"hide_input_from_viewers": true,
|
568 |
+
"hide_output_from_viewers": true
|
569 |
+
},
|
570 |
+
"gather": {
|
571 |
+
"logged": 1706415378024
|
572 |
+
}
|
573 |
+
}
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"cell_type": "code",
|
577 |
+
"source": [
|
578 |
+
"multi_label_trainer.train()"
|
579 |
+
],
|
580 |
+
"outputs": [
|
581 |
+
{
|
582 |
+
"output_type": "display_data",
|
583 |
+
"data": {
|
584 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
585 |
+
"text/html": "\n <div>\n \n <progress value='3001' max='4605' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [3001/4605 12:05 < 06:28, 4.13 it/s, Epoch 0.65/1]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
|
586 |
+
},
|
587 |
+
"metadata": {}
|
588 |
+
},
|
589 |
+
{
|
590 |
+
"output_type": "stream",
|
591 |
+
"name": "stderr",
|
592 |
+
"text": "Checkpoint destination directory vaers/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.2s\nCheckpoint destination directory vaers/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 13.4s\nCheckpoint destination directory vaers/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 13.0s\nCheckpoint destination directory vaers/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 11.6s\nCheckpoint destination directory vaers/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.6s\nCheckpoint destination directory vaers/checkpoint-3000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... "
|
593 |
+
}
|
594 |
+
],
|
595 |
+
"execution_count": 19,
|
596 |
+
"metadata": {
|
597 |
+
"datalore": {
|
598 |
+
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
|
599 |
+
"type": "CODE",
|
600 |
+
"hide_input_from_viewers": true,
|
601 |
+
"hide_output_from_viewers": true
|
602 |
+
},
|
603 |
+
"gather": {
|
604 |
+
"logged": 1706411445752
|
605 |
+
}
|
606 |
+
}
|
607 |
+
},
|
608 |
+
{
|
609 |
+
"cell_type": "markdown",
|
610 |
+
"source": [
|
611 |
+
"### Evaluation"
|
612 |
+
],
|
613 |
+
"metadata": {
|
614 |
+
"collapsed": false
|
615 |
+
}
|
616 |
+
},
|
617 |
+
{
|
618 |
+
"cell_type": "markdown",
|
619 |
+
"source": [
|
620 |
+
"We instantiate a classifier `pipeline` and push it to CUDA."
|
621 |
+
],
|
622 |
+
"metadata": {
|
623 |
+
"collapsed": false
|
624 |
+
}
|
625 |
+
},
|
626 |
+
{
|
627 |
+
"cell_type": "code",
|
628 |
+
"source": [
|
629 |
+
"classifier = pipeline(\"text-classification\", \n",
|
630 |
+
" model, \n",
|
631 |
+
" tokenizer=tokenizer, \n",
|
632 |
+
" device=\"cuda:0\")"
|
633 |
+
],
|
634 |
+
"outputs": [],
|
635 |
+
"execution_count": null,
|
636 |
+
"metadata": {
|
637 |
+
"datalore": {
|
638 |
+
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
|
639 |
+
"type": "CODE",
|
640 |
+
"hide_input_from_viewers": true,
|
641 |
+
"hide_output_from_viewers": true
|
642 |
+
},
|
643 |
+
"gather": {
|
644 |
+
"logged": 1706411459928
|
645 |
+
}
|
646 |
+
}
|
647 |
+
},
|
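A quick smoke test of the pipeline might look like the sketch below; the report text is invented, and because the model config carries no `id2label` mapping the returned label names default to `LABEL_0`..`LABEL_3`, which map back to `CLASS_NAMES` by index:

sample = "Patient developed fever and was seen in the emergency room."  # invented
print(classifier(sample))
# e.g. [{'label': 'LABEL_2', 'score': 0.71}] -- illustrative output only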
648 |
+
{
|
649 |
+
"cell_type": "markdown",
|
650 |
+
"source": [
|
651 |
+
"We use the same tokenizer used for training to tokenize/encode the validation set."
|
652 |
+
],
|
653 |
+
"metadata": {
|
654 |
+
"collapsed": false
|
655 |
+
}
|
656 |
+
},
|
657 |
+
{
|
658 |
+
"cell_type": "code",
|
659 |
+
"source": [
|
660 |
+
"test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
|
661 |
+
" max_length=None, \n",
|
662 |
+
" padding='max_length', \n",
|
663 |
+
" return_token_type_ids=True, \n",
|
664 |
+
" truncation=True)"
|
665 |
+
],
|
666 |
+
"outputs": [],
|
667 |
+
"execution_count": null,
|
668 |
+
"metadata": {
|
669 |
+
"datalore": {
|
670 |
+
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
|
671 |
+
"type": "CODE",
|
672 |
+
"hide_input_from_viewers": true,
|
673 |
+
"hide_output_from_viewers": true
|
674 |
+
},
|
675 |
+
"gather": {
|
676 |
+
"logged": 1706411523285
|
677 |
+
}
|
678 |
+
}
|
679 |
+
},
|
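`batch_encode_plus` returns parallel per-text lists of `input_ids`, `attention_mask` and, since `return_token_type_ids=True` is passed, `token_type_ids`; with `padding='max_length'` and `max_length=None`, every row is padded to the tokenizer's model maximum (512 for this checkpoint). A toy call, assuming the tokenizer above:

enc = tokenizer.batch_encode_plus(["patient reported headache"],
                                  max_length=None,
                                  padding='max_length',
                                  return_token_type_ids=True,
                                  truncation=True)
print(sorted(enc.keys()))        # ['attention_mask', 'input_ids', 'token_type_ids']
print(len(enc['input_ids'][0]))  # 512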
680 |
+
{
|
681 |
+
"cell_type": "markdown",
|
682 |
+
"source": [
|
683 |
+
"Once we've made the data loadable by putting it into a `DataLoader`, we "
|
684 |
+
],
|
685 |
+
"metadata": {
|
686 |
+
"collapsed": false
|
687 |
+
}
|
688 |
+
},
|
689 |
+
{
|
690 |
+
"cell_type": "code",
|
691 |
+
"source": [
|
692 |
+
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
|
693 |
+
" torch.tensor(test_encodings['attention_mask']), \n",
|
694 |
+
" torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
|
695 |
+
" torch.tensor(test_encodings['token_type_ids']))\n",
|
696 |
+
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
|
697 |
+
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
|
698 |
+
" batch_size=BATCH_SIZE)"
|
699 |
+
],
|
700 |
+
"outputs": [],
|
701 |
+
"execution_count": null,
|
702 |
+
"metadata": {
|
703 |
+
"datalore": {
|
704 |
+
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
|
705 |
+
"type": "CODE",
|
706 |
+
"hide_input_from_viewers": true,
|
707 |
+
"hide_output_from_viewers": true
|
708 |
+
},
|
709 |
+
"gather": {
|
710 |
+
"logged": 1706411543379
|
711 |
+
}
|
712 |
+
}
|
713 |
+
},
|
714 |
+
{
|
715 |
+
"cell_type": "code",
|
716 |
+
"source": [
|
717 |
+
"model.eval()\n",
|
718 |
+
"\n",
|
719 |
+
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
|
720 |
+
"\n",
|
721 |
+
"for i, batch in enumerate(test_dataloader):\n",
|
722 |
+
" batch = tuple(t.to(device) for t in batch)\n",
|
723 |
+
" \n",
|
724 |
+
" # Unpack the inputs from our dataloader\n",
|
725 |
+
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
|
726 |
+
" \n",
|
727 |
+
" with torch.no_grad():\n",
|
728 |
+
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
|
729 |
+
" b_logit_pred = outs[0]\n",
|
730 |
+
" pred_label = torch.sigmoid(b_logit_pred)\n",
|
731 |
+
"\n",
|
732 |
+
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
|
733 |
+
" pred_label = pred_label.to('cpu').numpy()\n",
|
734 |
+
" b_labels = b_labels.to('cpu').numpy()\n",
|
735 |
+
"\n",
|
736 |
+
" tokenized_texts.append(b_input_ids)\n",
|
737 |
+
" logit_preds.append(b_logit_pred)\n",
|
738 |
+
" true_labels.append(b_labels)\n",
|
739 |
+
" pred_labels.append(pred_label)\n",
|
740 |
+
"\n",
|
741 |
+
"# Flatten outputs\n",
|
742 |
+
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
|
743 |
+
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
|
744 |
+
"true_labels = [item for sublist in true_labels for item in sublist]\n",
|
745 |
+
"\n",
|
746 |
+
"# Converting flattened binary values to boolean values\n",
|
747 |
+
"true_bools = [tl == 1 for tl in true_labels]\n",
|
748 |
+
"pred_bools = [pl > 0.50 for pl in pred_labels] "
|
749 |
+
],
|
750 |
+
"outputs": [],
|
751 |
+
"execution_count": null,
|
752 |
+
"metadata": {
|
753 |
+
"datalore": {
|
754 |
+
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
|
755 |
+
"type": "CODE",
|
756 |
+
"hide_input_from_viewers": true,
|
757 |
+
"hide_output_from_viewers": true
|
758 |
+
},
|
759 |
+
"gather": {
|
760 |
+
"logged": 1706411587843
|
761 |
+
}
|
762 |
+
}
|
763 |
+
},
|
764 |
+
{
|
765 |
+
"cell_type": "markdown",
|
766 |
+
"source": [
|
767 |
+
"We create a classification report:"
|
768 |
+
],
|
769 |
+
"metadata": {
|
770 |
+
"collapsed": false
|
771 |
+
}
|
772 |
+
},
|
773 |
+
{
|
774 |
+
"cell_type": "code",
|
775 |
+
"source": [
|
776 |
+
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
|
777 |
+
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
|
778 |
+
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
|
779 |
+
"print(clf_report)"
|
780 |
+
],
|
781 |
+
"outputs": [],
|
782 |
+
"execution_count": null,
|
783 |
+
"metadata": {
|
784 |
+
"datalore": {
|
785 |
+
"node_id": "eBprrgF086mznPbPVBpOLS",
|
786 |
+
"type": "CODE",
|
787 |
+
"hide_input_from_viewers": true,
|
788 |
+
"hide_output_from_viewers": true
|
789 |
+
},
|
790 |
+
"gather": {
|
791 |
+
"logged": 1706411588249
|
792 |
+
}
|
793 |
+
}
|
794 |
+
},
|
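Micro-averaged F1 pools true positives, false positives and false negatives across all four labels before computing precision and recall, so the more frequent labels dominate the score. A small sketch with invented predictions:

from sklearn.metrics import f1_score

y_true = [[1, 0], [1, 1]]  # two examples, two labels, multi-hot
y_pred = [[1, 0], [0, 1]]

# Pooled counts: TP=2, FP=0, FN=1 -> precision 1.0, recall 2/3, F1 0.8
print(f1_score(y_true, y_pred, average='micro'))  # 0.8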
795 |
+
{
|
796 |
+
"cell_type": "markdown",
|
797 |
+
"source": [
|
798 |
+
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
|
799 |
+
],
|
800 |
+
"metadata": {
|
801 |
+
"collapsed": false
|
802 |
+
}
|
803 |
+
},
|
804 |
+
{
|
805 |
+
"cell_type": "code",
|
806 |
+
"source": [
|
807 |
+
"# Creating a map of class names from class numbers\n",
|
808 |
+
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
|
809 |
+
],
|
810 |
+
"outputs": [],
|
811 |
+
"execution_count": null,
|
812 |
+
"metadata": {
|
813 |
+
"datalore": {
|
814 |
+
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
|
815 |
+
"type": "CODE",
|
816 |
+
"hide_input_from_viewers": true,
|
817 |
+
"hide_output_from_viewers": true
|
818 |
+
},
|
819 |
+
"gather": {
|
820 |
+
"logged": 1706411588638
|
821 |
+
}
|
822 |
+
}
|
823 |
+
},
|
824 |
+
{
|
825 |
+
"cell_type": "code",
|
826 |
+
"source": [
|
827 |
+
"true_label_idxs, pred_label_idxs = [], []\n",
|
828 |
+
"\n",
|
829 |
+
"for vals in true_bools:\n",
|
830 |
+
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
|
831 |
+
"for vals in pred_bools:\n",
|
832 |
+
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
|
833 |
+
],
|
834 |
+
"outputs": [],
|
835 |
+
"execution_count": null,
|
836 |
+
"metadata": {
|
837 |
+
"datalore": {
|
838 |
+
"node_id": "jH0S35dDteUch01sa6me6e",
|
839 |
+
"type": "CODE",
|
840 |
+
"hide_input_from_viewers": true,
|
841 |
+
"hide_output_from_viewers": true
|
842 |
+
},
|
843 |
+
"gather": {
|
844 |
+
"logged": 1706411589004
|
845 |
+
}
|
846 |
+
}
|
847 |
+
},
|
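Each `np.where(vals)[0]` call turns a multi-hot boolean vector into the indices of its active labels, which `idx2label` then renders as names. On a toy vector:

import numpy as np

CLASS_NAMES = ["DIED", "ER_VISIT", "HOSPITAL", "OFC_VISIT"]
idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))

vals = np.array([True, False, True, False])
idxs = np.where(vals)[0].flatten().tolist()  # [0, 2]
print([idx2label[i] for i in idxs])          # ['DIED', 'HOSPITAL']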
848 |
+
{
|
849 |
+
"cell_type": "code",
|
850 |
+
"source": [
|
851 |
+
"true_label_texts, pred_label_texts = [], []\n",
|
852 |
+
"\n",
|
853 |
+
"for vals in true_label_idxs:\n",
|
854 |
+
" if vals:\n",
|
855 |
+
" true_label_texts.append([idx2label[val] for val in vals])\n",
|
856 |
+
" else:\n",
|
857 |
+
" true_label_texts.append(vals)\n",
|
858 |
+
"\n",
|
859 |
+
"for vals in pred_label_idxs:\n",
|
860 |
+
" if vals:\n",
|
861 |
+
" pred_label_texts.append([idx2label[val] for val in vals])\n",
|
862 |
+
" else:\n",
|
863 |
+
" pred_label_texts.append(vals)"
|
864 |
+
],
|
865 |
+
"outputs": [],
|
866 |
+
"execution_count": null,
|
867 |
+
"metadata": {
|
868 |
+
"datalore": {
|
869 |
+
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
|
870 |
+
"type": "CODE",
|
871 |
+
"hide_input_from_viewers": true,
|
872 |
+
"hide_output_from_viewers": true
|
873 |
+
},
|
874 |
+
"gather": {
|
875 |
+
"logged": 1706411589301
|
876 |
+
}
|
877 |
+
}
|
878 |
+
},
|
879 |
+
{
|
880 |
+
"cell_type": "code",
|
881 |
+
"source": [
|
882 |
+
"symptom_texts = [tokenizer.decode(text,\n",
|
883 |
+
" skip_special_tokens=True,\n",
|
884 |
+
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
|
885 |
+
],
|
886 |
+
"outputs": [],
|
887 |
+
"execution_count": null,
|
888 |
+
"metadata": {
|
889 |
+
"datalore": {
|
890 |
+
"node_id": "SxUmVHfQISEeptg1SawOmB",
|
891 |
+
"type": "CODE",
|
892 |
+
"hide_input_from_viewers": true,
|
893 |
+
"hide_output_from_viewers": true
|
894 |
+
},
|
895 |
+
"gather": {
|
896 |
+
"logged": 1706411591952
|
897 |
+
}
|
898 |
+
}
|
899 |
+
},
|
900 |
+
{
|
901 |
+
"cell_type": "code",
|
902 |
+
"source": [
|
903 |
+
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
|
904 |
+
" 'true_labels': true_label_texts, \n",
|
905 |
+
" 'pred_labels':pred_label_texts})\n",
|
906 |
+
"comparisons_df.to_csv('comparisons.csv')\n",
|
907 |
+
"comparisons_df"
|
908 |
+
],
|
909 |
+
"outputs": [],
|
910 |
+
"execution_count": null,
|
911 |
+
"metadata": {
|
912 |
+
"datalore": {
|
913 |
+
"node_id": "BxFNigNGRLTOqraI55BPSH",
|
914 |
+
"type": "CODE",
|
915 |
+
"hide_input_from_viewers": true,
|
916 |
+
"hide_output_from_viewers": true
|
917 |
+
},
|
918 |
+
"gather": {
|
919 |
+
"logged": 1706411592512
|
920 |
+
}
|
921 |
+
}
|
922 |
+
},
|
923 |
+
{
|
924 |
+
"cell_type": "markdown",
|
925 |
+
"source": [
|
926 |
+
"### Shapley analysis"
|
927 |
+
],
|
928 |
+
"metadata": {
|
929 |
+
"collapsed": false
|
930 |
+
}
|
931 |
+
},
|
932 |
+
{
|
933 |
+
"cell_type": "code",
|
934 |
+
"source": [
|
935 |
+
"explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
|
936 |
+
],
|
937 |
+
"outputs": [],
|
938 |
+
"execution_count": null,
|
939 |
+
"metadata": {
|
940 |
+
"datalore": {
|
941 |
+
"node_id": "OpdZcoenX2HwzLdai7K5UA",
|
942 |
+
"type": "CODE",
|
943 |
+
"hide_input_from_viewers": true,
|
944 |
+
"hide_output_from_viewers": true
|
945 |
+
},
|
946 |
+
"gather": {
|
947 |
+
"logged": 1706415109071
|
948 |
+
}
|
949 |
+
}
|
950 |
+
},
|
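`shap.Explainer` wraps the text-classification pipeline as a black-box model with a text masker and attributes each class score to input tokens. A minimal sketch for a single, invented report (assuming `classifier` and `CLASS_NAMES` as defined above):

import shap

explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)
sv = explainer(["Patient was hospitalised overnight with chest pain."])  # invented text
shap.plots.text(sv[0])  # per-token attributions for this one report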
951 |
+
{
|
952 |
+
"cell_type": "markdown",
|
953 |
+
"source": [
|
954 |
+
"#### Sampling correct predictions\n",
|
955 |
+
"\n",
|
956 |
+
"First, let's look at some correct predictions of deaths:"
|
957 |
+
],
|
958 |
+
"metadata": {
|
959 |
+
"nteract": {
|
960 |
+
"transient": {
|
961 |
+
"deleting": false
|
962 |
+
}
|
963 |
+
}
|
964 |
+
}
|
965 |
+
},
|
966 |
+
{
|
967 |
+
"cell_type": "code",
|
968 |
+
"source": [
|
969 |
+
"correct_death_predictions = comparisons_df[comparisons_df['true_labels'].astype(str) == \"['DIED']\"]"
|
970 |
+
],
|
971 |
+
"outputs": [],
|
972 |
+
"execution_count": null,
|
973 |
+
"metadata": {
|
974 |
+
"jupyter": {
|
975 |
+
"source_hidden": false,
|
976 |
+
"outputs_hidden": false
|
977 |
+
},
|
978 |
+
"nteract": {
|
979 |
+
"transient": {
|
980 |
+
"deleting": false
|
981 |
+
}
|
982 |
+
},
|
983 |
+
"gather": {
|
984 |
+
"logged": 1706414973990
|
985 |
+
}
|
986 |
+
}
|
987 |
+
},
|
988 |
+
{
|
989 |
+
"cell_type": "code",
|
990 |
+
"source": [
|
991 |
+
"texts = [i[:512] for i in correct_death_predictions.sample(n=6).symptom_text]\n",
|
992 |
+
"idxs = [i for i in range(len(texts))]\n",
|
993 |
+
"\n",
|
994 |
+
"d_s = Dataset(Table.from_arrays([idxs, texts], names=[\"idx\", \"texts\"]))"
|
995 |
+
],
|
996 |
+
"outputs": [],
|
997 |
+
"execution_count": null,
|
998 |
+
"metadata": {
|
999 |
+
"jupyter": {
|
1000 |
+
"source_hidden": false,
|
1001 |
+
"outputs_hidden": false
|
1002 |
+
},
|
1003 |
+
"nteract": {
|
1004 |
+
"transient": {
|
1005 |
+
"deleting": false
|
1006 |
+
}
|
1007 |
+
},
|
1008 |
+
"gather": {
|
1009 |
+
"logged": 1706415114683
|
1010 |
+
}
|
1011 |
+
}
|
1012 |
+
},
|
1013 |
+
{
|
1014 |
+
"cell_type": "code",
|
1015 |
+
"source": [
|
1016 |
+
"shap_values = explainer(d_s[\"texts\"])"
|
1017 |
+
],
|
1018 |
+
"outputs": [],
|
1019 |
+
"execution_count": null,
|
1020 |
+
"metadata": {
|
1021 |
+
"jupyter": {
|
1022 |
+
"source_hidden": false,
|
1023 |
+
"outputs_hidden": false
|
1024 |
+
},
|
1025 |
+
"nteract": {
|
1026 |
+
"transient": {
|
1027 |
+
"deleting": false
|
1028 |
+
}
|
1029 |
+
},
|
1030 |
+
"gather": {
|
1031 |
+
"logged": 1706415129229
|
1032 |
+
}
|
1033 |
+
}
|
1034 |
+
},
|
1035 |
+
{
|
1036 |
+
"cell_type": "code",
|
1037 |
+
"source": [
|
1038 |
+
"shap.plots.text(shap_values)"
|
1039 |
+
],
|
1040 |
+
"outputs": [],
|
1041 |
+
"execution_count": null,
|
1042 |
+
"metadata": {
|
1043 |
+
"jupyter": {
|
1044 |
+
"source_hidden": false,
|
1045 |
+
"outputs_hidden": false
|
1046 |
+
},
|
1047 |
+
"nteract": {
|
1048 |
+
"transient": {
|
1049 |
+
"deleting": false
|
1050 |
+
}
|
1051 |
+
},
|
1052 |
+
"gather": {
|
1053 |
+
"logged": 1706415151494
|
1054 |
+
}
|
1055 |
+
}
|
1056 |
+
},
|
1057 |
+
{
|
1058 |
+
"cell_type": "code",
|
1059 |
+
"source": [],
|
1060 |
+
"outputs": [],
|
1061 |
+
"execution_count": null,
|
1062 |
+
"metadata": {
|
1063 |
+
"jupyter": {
|
1064 |
+
"source_hidden": false,
|
1065 |
+
"outputs_hidden": false
|
1066 |
+
},
|
1067 |
+
"nteract": {
|
1068 |
+
"transient": {
|
1069 |
+
"deleting": false
|
1070 |
+
}
|
1071 |
+
}
|
1072 |
+
}
|
1073 |
+
}
|
1074 |
+
],
|
1075 |
+
"metadata": {
|
1076 |
+
"kernelspec": {
|
1077 |
+
"name": "python3",
|
1078 |
+
"language": "python",
|
1079 |
+
"display_name": "Python 3 (ipykernel)"
|
1080 |
+
},
|
1081 |
+
"datalore": {
|
1082 |
+
"computation_mode": "JUPYTER",
|
1083 |
+
"package_manager": "pip",
|
1084 |
+
"base_environment": "default",
|
1085 |
+
"packages": [
|
1086 |
+
{
|
1087 |
+
"name": "datasets",
|
1088 |
+
"version": "2.16.1",
|
1089 |
+
"source": "PIP"
|
1090 |
+
},
|
1091 |
+
{
|
1092 |
+
"name": "torch",
|
1093 |
+
"version": "2.1.2",
|
1094 |
+
"source": "PIP"
|
1095 |
+
},
|
1096 |
+
{
|
1097 |
+
"name": "accelerate",
|
1098 |
+
"version": "0.26.1",
|
1099 |
+
"source": "PIP"
|
1100 |
+
}
|
1101 |
+
],
|
1102 |
+
"report_row_ids": [
|
1103 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
1104 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
1105 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
1106 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
1107 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
1108 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
1109 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
1110 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
1111 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
1112 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
1113 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
1114 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
1115 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
1116 |
+
],
|
1117 |
+
"version": 3
|
1118 |
+
},
|
1119 |
+
"microsoft": {
|
1120 |
+
"ms_spell_check": {
|
1121 |
+
"ms_spell_check_language": "en"
|
1122 |
+
},
|
1123 |
+
"host": {
|
1124 |
+
"AzureML": {
|
1125 |
+
"notebookHasBeenCompleted": true
|
1126 |
+
}
|
1127 |
+
}
|
1128 |
+
},
|
1129 |
+
"language_info": {
|
1130 |
+
"name": "python",
|
1131 |
+
"version": "3.8.5",
|
1132 |
+
"mimetype": "text/x-python",
|
1133 |
+
"codemirror_mode": {
|
1134 |
+
"name": "ipython",
|
1135 |
+
"version": 3
|
1136 |
+
},
|
1137 |
+
"pygments_lexer": "ipython3",
|
1138 |
+
"nbconvert_exporter": "python",
|
1139 |
+
"file_extension": ".py"
|
1140 |
+
},
|
1141 |
+
"nteract": {
|
1142 |
+
"version": "nteract-front-end@1.0.0"
|
1143 |
+
}
|
1144 |
+
},
|
1145 |
+
"nbformat": 4,
|
1146 |
+
"nbformat_minor": 4
|
1147 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-15-7-36Z.ipynb
ADDED
@@ -0,0 +1,1452 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
8 |
+
"\n",
|
9 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {
|
16 |
+
"nteract": {
|
17 |
+
"transient": {
|
18 |
+
"deleting": false
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"tags": []
|
22 |
+
},
|
23 |
+
"outputs": [],
|
24 |
+
"source": [
|
25 |
+
"# %pip install accelerate -U"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 2,
|
31 |
+
"metadata": {
|
32 |
+
"nteract": {
|
33 |
+
"transient": {
|
34 |
+
"deleting": false
|
35 |
+
}
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"outputs": [],
|
39 |
+
"source": [
|
40 |
+
"# %pip install transformers datasets shap watermark wandb"
|
41 |
+
]
|
42 |
+
},
|
43 |
+
{
|
44 |
+
"cell_type": "code",
|
45 |
+
"execution_count": 3,
|
46 |
+
"metadata": {
|
47 |
+
"datalore": {
|
48 |
+
"hide_input_from_viewers": false,
|
49 |
+
"hide_output_from_viewers": false,
|
50 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
51 |
+
"report_properties": {
|
52 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
53 |
+
},
|
54 |
+
"type": "CODE"
|
55 |
+
},
|
56 |
+
"gather": {
|
57 |
+
"logged": 1706449625034
|
58 |
+
},
|
59 |
+
"tags": []
|
60 |
+
},
|
61 |
+
"outputs": [
|
62 |
+
{
|
63 |
+
"name": "stderr",
|
64 |
+
"output_type": "stream",
|
65 |
+
"text": [
|
66 |
+
"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
67 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
68 |
+
"2024-01-28 14:18:31.729214: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
|
69 |
+
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
70 |
+
"2024-01-28 14:18:32.746966: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
|
71 |
+
"2024-01-28 14:18:32.747096: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
|
72 |
+
"2024-01-28 14:18:32.747111: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
73 |
+
]
|
74 |
+
}
|
75 |
+
],
|
76 |
+
"source": [
|
77 |
+
"import pandas as pd\n",
|
78 |
+
"import numpy as np\n",
|
79 |
+
"import torch\n",
|
80 |
+
"import os\n",
|
81 |
+
"from typing import List\n",
|
82 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
83 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
|
84 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
85 |
+
"from pyarrow import Table\n",
|
86 |
+
"import shap\n",
|
87 |
+
"import wandb\n",
|
88 |
+
"\n",
|
89 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
90 |
+
"\n",
|
91 |
+
"%load_ext watermark"
|
92 |
+
]
|
93 |
+
},
|
94 |
+
{
|
95 |
+
"cell_type": "code",
|
96 |
+
"execution_count": 4,
|
97 |
+
"metadata": {
|
98 |
+
"collapsed": false,
|
99 |
+
"gather": {
|
100 |
+
"logged": 1706449721319
|
101 |
+
},
|
102 |
+
"jupyter": {
|
103 |
+
"outputs_hidden": false
|
104 |
+
}
|
105 |
+
},
|
106 |
+
"outputs": [],
|
107 |
+
"source": [
|
108 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
109 |
+
"\n",
|
110 |
+
"SEED: int = 42\n",
|
111 |
+
"\n",
|
112 |
+
"BATCH_SIZE: int = 16\n",
|
113 |
+
"EPOCHS: int = 3\n",
|
114 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
115 |
+
"\n",
|
116 |
+
"CLASS_NAMES: List[str] = [\"DIED\",\n",
|
117 |
+
" \"ER_VISIT\",\n",
|
118 |
+
" \"HOSPITAL\",\n",
|
119 |
+
" \"OFC_VISIT\",\n",
|
120 |
+
" #\"X_STAY\", # pruned\n",
|
121 |
+
" #\"DISABLE\", # pruned\n",
|
122 |
+
" #\"D_PRESENTED\" # pruned\n",
|
123 |
+
" ]\n",
|
124 |
+
"\n",
|
125 |
+
"\n",
|
126 |
+
"\n",
|
127 |
+
"\n",
|
128 |
+
"# WandB configuration\n",
|
129 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
|
130 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
131 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
132 |
+
]
|
133 |
+
},
|
134 |
+
{
|
135 |
+
"cell_type": "code",
|
136 |
+
"execution_count": 5,
|
137 |
+
"metadata": {
|
138 |
+
"collapsed": false,
|
139 |
+
"jupyter": {
|
140 |
+
"outputs_hidden": false
|
141 |
+
}
|
142 |
+
},
|
143 |
+
"outputs": [
|
144 |
+
{
|
145 |
+
"name": "stdout",
|
146 |
+
"output_type": "stream",
|
147 |
+
"text": [
|
148 |
+
"torch : 1.12.0\n",
|
149 |
+
"pandas : 2.0.2\n",
|
150 |
+
"numpy : 1.23.5\n",
|
151 |
+
"shap : 0.44.1\n",
|
152 |
+
"re : 2.2.1\n",
|
153 |
+
"wandb : 0.16.2\n",
|
154 |
+
"logging: 0.5.1.2\n",
|
155 |
+
"\n"
|
156 |
+
]
|
157 |
+
}
|
158 |
+
],
|
159 |
+
"source": [
|
160 |
+
"%watermark --iversion"
|
161 |
+
]
|
162 |
+
},
|
163 |
+
{
|
164 |
+
"cell_type": "code",
|
165 |
+
"execution_count": 6,
|
166 |
+
"metadata": {
|
167 |
+
"datalore": {
|
168 |
+
"hide_input_from_viewers": true,
|
169 |
+
"hide_output_from_viewers": true,
|
170 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
171 |
+
"type": "CODE"
|
172 |
+
}
|
173 |
+
},
|
174 |
+
"outputs": [
|
175 |
+
{
|
176 |
+
"name": "stdout",
|
177 |
+
"output_type": "stream",
|
178 |
+
"text": [
|
179 |
+
"Sun Jan 28 14:18:35 2024 \n",
|
180 |
+
"+---------------------------------------------------------------------------------------+\n",
|
181 |
+
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
|
182 |
+
"|-----------------------------------------+----------------------+----------------------+\n",
|
183 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
184 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
185 |
+
"| | | MIG M. |\n",
|
186 |
+
"|=========================================+======================+======================|\n",
|
187 |
+
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
|
188 |
+
"| N/A 29C P0 27W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
189 |
+
"| | | N/A |\n",
|
190 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
191 |
+
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
|
192 |
+
"| N/A 29C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
193 |
+
"| | | N/A |\n",
|
194 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
195 |
+
" \n",
|
196 |
+
"+---------------------------------------------------------------------------------------+\n",
|
197 |
+
"| Processes: |\n",
|
198 |
+
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
199 |
+
"| ID ID Usage |\n",
|
200 |
+
"|=======================================================================================|\n",
|
201 |
+
"| No running processes found |\n",
|
202 |
+
"+---------------------------------------------------------------------------------------+\n"
|
203 |
+
]
|
204 |
+
}
|
205 |
+
],
|
206 |
+
"source": [
|
207 |
+
"!nvidia-smi"
|
208 |
+
]
|
209 |
+
},
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"metadata": {
|
213 |
+
"datalore": {
|
214 |
+
"hide_input_from_viewers": false,
|
215 |
+
"hide_output_from_viewers": false,
|
216 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
217 |
+
"report_properties": {
|
218 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
219 |
+
},
|
220 |
+
"type": "MD"
|
221 |
+
}
|
222 |
+
},
|
223 |
+
"source": [
|
224 |
+
"## Loading the data set"
|
225 |
+
]
|
226 |
+
},
|
227 |
+
{
|
228 |
+
"cell_type": "code",
|
229 |
+
"execution_count": 7,
|
230 |
+
"metadata": {
|
231 |
+
"collapsed": false,
|
232 |
+
"gather": {
|
233 |
+
"logged": 1706449040507
|
234 |
+
},
|
235 |
+
"jupyter": {
|
236 |
+
"outputs_hidden": false
|
237 |
+
}
|
238 |
+
},
|
239 |
+
"outputs": [],
|
240 |
+
"source": [
|
241 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
242 |
+
]
|
243 |
+
},
|
244 |
+
{
|
245 |
+
"cell_type": "code",
|
246 |
+
"execution_count": 8,
|
247 |
+
"metadata": {
|
248 |
+
"collapsed": false,
|
249 |
+
"gather": {
|
250 |
+
"logged": 1706449044205
|
251 |
+
},
|
252 |
+
"jupyter": {
|
253 |
+
"outputs_hidden": false,
|
254 |
+
"source_hidden": false
|
255 |
+
},
|
256 |
+
"nteract": {
|
257 |
+
"transient": {
|
258 |
+
"deleting": false
|
259 |
+
}
|
260 |
+
}
|
261 |
+
},
|
262 |
+
"outputs": [
|
263 |
+
{
|
264 |
+
"data": {
|
265 |
+
"text/plain": [
|
266 |
+
"DatasetDict({\n",
|
267 |
+
" train: Dataset({\n",
|
268 |
+
" features: ['id', 'text', 'labels'],\n",
|
269 |
+
" num_rows: 1270444\n",
|
270 |
+
" })\n",
|
271 |
+
" test: Dataset({\n",
|
272 |
+
" features: ['id', 'text', 'labels'],\n",
|
273 |
+
" num_rows: 272238\n",
|
274 |
+
" })\n",
|
275 |
+
" val: Dataset({\n",
|
276 |
+
" features: ['id', 'text', 'labels'],\n",
|
277 |
+
" num_rows: 272238\n",
|
278 |
+
" })\n",
|
279 |
+
"})"
|
280 |
+
]
|
281 |
+
},
|
282 |
+
"execution_count": 8,
|
283 |
+
"metadata": {},
|
284 |
+
"output_type": "execute_result"
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"dataset"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 9,
|
294 |
+
"metadata": {},
|
295 |
+
"outputs": [],
|
296 |
+
"source": [
|
297 |
+
"SUBSAMPLING: float = 0.1"
|
298 |
+
]
|
299 |
+
},
|
300 |
+
{
|
301 |
+
"cell_type": "code",
|
302 |
+
"execution_count": 10,
|
303 |
+
"metadata": {
|
304 |
+
"collapsed": false,
|
305 |
+
"gather": {
|
306 |
+
"logged": 1706449378281
|
307 |
+
},
|
308 |
+
"jupyter": {
|
309 |
+
"outputs_hidden": false,
|
310 |
+
"source_hidden": false
|
311 |
+
},
|
312 |
+
"nteract": {
|
313 |
+
"transient": {
|
314 |
+
"deleting": false
|
315 |
+
}
|
316 |
+
}
|
317 |
+
},
|
318 |
+
"outputs": [],
|
319 |
+
"source": [
|
320 |
+
"def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
|
321 |
+
" res = DatasetDict()\n",
|
322 |
+
"\n",
|
323 |
+
" res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
|
324 |
+
" res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
|
325 |
+
" res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
|
326 |
+
" \n",
|
327 |
+
" return res"
|
328 |
+
]
|
329 |
+
},
|
330 |
+
{
|
331 |
+
"cell_type": "code",
|
332 |
+
"execution_count": 11,
|
333 |
+
"metadata": {
|
334 |
+
"collapsed": false,
|
335 |
+
"gather": {
|
336 |
+
"logged": 1706449384162
|
337 |
+
},
|
338 |
+
"jupyter": {
|
339 |
+
"outputs_hidden": false,
|
340 |
+
"source_hidden": false
|
341 |
+
},
|
342 |
+
"nteract": {
|
343 |
+
"transient": {
|
344 |
+
"deleting": false
|
345 |
+
}
|
346 |
+
}
|
347 |
+
},
|
348 |
+
"outputs": [],
|
349 |
+
"source": [
|
350 |
+
"dataset = minisample(dataset, SUBSAMPLING)"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"execution_count": 12,
|
356 |
+
"metadata": {
|
357 |
+
"collapsed": false,
|
358 |
+
"gather": {
|
359 |
+
"logged": 1706449387981
|
360 |
+
},
|
361 |
+
"jupyter": {
|
362 |
+
"outputs_hidden": false,
|
363 |
+
"source_hidden": false
|
364 |
+
},
|
365 |
+
"nteract": {
|
366 |
+
"transient": {
|
367 |
+
"deleting": false
|
368 |
+
}
|
369 |
+
}
|
370 |
+
},
|
371 |
+
"outputs": [
|
372 |
+
{
|
373 |
+
"data": {
|
374 |
+
"text/plain": [
|
375 |
+
"DatasetDict({\n",
|
376 |
+
" train: Dataset({\n",
|
377 |
+
" features: ['id', 'text', 'labels'],\n",
|
378 |
+
" num_rows: 127044\n",
|
379 |
+
" })\n",
|
380 |
+
" test: Dataset({\n",
|
381 |
+
" features: ['id', 'text', 'labels'],\n",
|
382 |
+
" num_rows: 27224\n",
|
383 |
+
" })\n",
|
384 |
+
" val: Dataset({\n",
|
385 |
+
" features: ['id', 'text', 'labels'],\n",
|
386 |
+
" num_rows: 27224\n",
|
387 |
+
" })\n",
|
388 |
+
"})"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
"execution_count": 12,
|
392 |
+
"metadata": {},
|
393 |
+
"output_type": "execute_result"
|
394 |
+
}
|
395 |
+
],
|
396 |
+
"source": [
|
397 |
+
"dataset"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "markdown",
|
402 |
+
"metadata": {
|
403 |
+
"nteract": {
|
404 |
+
"transient": {
|
405 |
+
"deleting": false
|
406 |
+
}
|
407 |
+
}
|
408 |
+
},
|
409 |
+
"source": [
|
410 |
+
"We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
|
411 |
+
]
|
412 |
+
},
|
413 |
+
{
|
414 |
+
"cell_type": "code",
|
415 |
+
"execution_count": 13,
|
416 |
+
"metadata": {
|
417 |
+
"collapsed": false,
|
418 |
+
"gather": {
|
419 |
+
"logged": 1706449443055
|
420 |
+
},
|
421 |
+
"jupyter": {
|
422 |
+
"outputs_hidden": false,
|
423 |
+
"source_hidden": false
|
424 |
+
},
|
425 |
+
"nteract": {
|
426 |
+
"transient": {
|
427 |
+
"deleting": false
|
428 |
+
}
|
429 |
+
}
|
430 |
+
},
|
431 |
+
"outputs": [],
|
432 |
+
"source": [
|
433 |
+
"ds = DatasetDict()\n",
|
434 |
+
"\n",
|
435 |
+
"for i in [\"test\", \"train\", \"val\"]:\n",
|
436 |
+
" tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
|
437 |
+
" ds[i] = Dataset(tab)\n",
|
438 |
+
"\n",
|
439 |
+
"dataset = ds"
|
440 |
+
]
|
441 |
+
},
|
442 |
+
{
|
443 |
+
"cell_type": "markdown",
|
444 |
+
"metadata": {},
|
445 |
+
"source": [
|
446 |
+
"### Tokenisation and encoding"
|
447 |
+
]
|
448 |
+
},
|
449 |
+
{
|
450 |
+
"cell_type": "code",
|
451 |
+
"execution_count": 14,
|
452 |
+
"metadata": {
|
453 |
+
"datalore": {
|
454 |
+
"hide_input_from_viewers": true,
|
455 |
+
"hide_output_from_viewers": true,
|
456 |
+
"node_id": "I7n646PIscsUZRoHu6m7zm",
|
457 |
+
"type": "CODE"
|
458 |
+
},
|
459 |
+
"gather": {
|
460 |
+
"logged": 1706449638377
|
461 |
+
}
|
462 |
+
},
|
463 |
+
"outputs": [],
|
464 |
+
"source": [
|
465 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
|
466 |
+
]
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"cell_type": "code",
|
470 |
+
"execution_count": 15,
|
471 |
+
"metadata": {
|
472 |
+
"datalore": {
|
473 |
+
"hide_input_from_viewers": true,
|
474 |
+
"hide_output_from_viewers": true,
|
475 |
+
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
|
476 |
+
"type": "CODE"
|
477 |
+
},
|
478 |
+
"gather": {
|
479 |
+
"logged": 1706449642580
|
480 |
+
}
|
481 |
+
},
|
482 |
+
"outputs": [],
|
483 |
+
"source": [
|
484 |
+
"def tokenize_and_encode(examples):\n",
|
485 |
+
" return tokenizer(examples[\"text\"], truncation=True)"
|
486 |
+
]
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"cell_type": "code",
|
490 |
+
"execution_count": 16,
|
491 |
+
"metadata": {
|
492 |
+
"datalore": {
|
493 |
+
"hide_input_from_viewers": true,
|
494 |
+
"hide_output_from_viewers": true,
|
495 |
+
"node_id": "slHeNysZOX9uWS9PB7jFDb",
|
496 |
+
"type": "CODE"
|
497 |
+
},
|
498 |
+
"gather": {
|
499 |
+
"logged": 1706449721161
|
500 |
+
}
|
501 |
+
},
|
502 |
+
"outputs": [
|
503 |
+
{
|
504 |
+
"name": "stderr",
|
505 |
+
"output_type": "stream",
|
506 |
+
"text": [
|
507 |
+
"Map: 100%|██████████| 27224/27224 [00:10<00:00, 2638.52 examples/s]\n",
|
508 |
+
"Map: 100%|██████████| 127044/127044 [00:48<00:00, 2633.40 examples/s]\n",
|
509 |
+
"Map: 100%|██████████| 27224/27224 [00:10<00:00, 2613.19 examples/s]\n"
|
510 |
+
]
|
511 |
+
}
|
512 |
+
],
|
513 |
+
"source": [
|
514 |
+
"cols = dataset[\"train\"].column_names\n",
|
515 |
+
"cols.remove(\"labels\")\n",
|
516 |
+
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
|
517 |
+
]
|
518 |
+
},
|
519 |
+
{
|
520 |
+
"cell_type": "markdown",
|
521 |
+
"metadata": {},
|
522 |
+
"source": [
|
523 |
+
"### Training"
|
524 |
+
]
|
525 |
+
},
|
526 |
+
{
|
527 |
+
"cell_type": "code",
|
528 |
+
"execution_count": 17,
|
529 |
+
"metadata": {
|
530 |
+
"datalore": {
|
531 |
+
"hide_input_from_viewers": true,
|
532 |
+
"hide_output_from_viewers": true,
|
533 |
+
"node_id": "itXWkbDw9sqbkMuDP84QoT",
|
534 |
+
"type": "CODE"
|
535 |
+
},
|
536 |
+
"gather": {
|
537 |
+
"logged": 1706449743072
|
538 |
+
}
|
539 |
+
},
|
540 |
+
"outputs": [],
|
541 |
+
"source": [
|
542 |
+
"class MultiLabelTrainer(Trainer):\n",
|
543 |
+
" def compute_loss(self, model, inputs, return_outputs=False):\n",
|
544 |
+
" labels = inputs.pop(\"labels\")\n",
|
545 |
+
" outputs = model(**inputs)\n",
|
546 |
+
" logits = outputs.logits\n",
|
547 |
+
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
|
548 |
+
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
|
549 |
+
" labels.float().view(-1, self.model.config.num_labels))\n",
|
550 |
+
" return (loss, outputs) if return_outputs else loss"
|
551 |
+
]
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"cell_type": "code",
|
555 |
+
"execution_count": 18,
|
556 |
+
"metadata": {
|
557 |
+
"datalore": {
|
558 |
+
"hide_input_from_viewers": true,
|
559 |
+
"hide_output_from_viewers": true,
|
560 |
+
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
|
561 |
+
"type": "CODE"
|
562 |
+
},
|
563 |
+
"gather": {
|
564 |
+
"logged": 1706449761205
|
565 |
+
}
|
566 |
+
},
|
567 |
+
"outputs": [
|
568 |
+
{
|
569 |
+
"name": "stderr",
|
570 |
+
"output_type": "stream",
|
571 |
+
"text": [
|
572 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
|
573 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
574 |
+
]
|
575 |
+
}
|
576 |
+
],
|
577 |
+
"source": [
|
578 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
|
579 |
+
]
|
580 |
+
},
|
581 |
+
{
|
582 |
+
"cell_type": "code",
|
583 |
+
"execution_count": 19,
|
584 |
+
"metadata": {
|
585 |
+
"datalore": {
|
586 |
+
"hide_input_from_viewers": true,
|
587 |
+
"hide_output_from_viewers": true,
|
588 |
+
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
|
589 |
+
"type": "CODE"
|
590 |
+
},
|
591 |
+
"gather": {
|
592 |
+
"logged": 1706449761541
|
593 |
+
}
|
594 |
+
},
|
595 |
+
"outputs": [],
|
596 |
+
"source": [
|
597 |
+
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
|
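+
"    # Element-wise accuracy for multi-label output: sigmoid the logits,\n",
|
+
"    # threshold at 0.5, and average the per-position matches against the truth.\n",
|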
598 |
+
" y_pred = torch.from_numpy(y_pred)\n",
|
599 |
+
" y_true = torch.from_numpy(y_true)\n",
|
600 |
+
"\n",
|
601 |
+
" if sigmoid:\n",
|
602 |
+
" y_pred = y_pred.sigmoid()\n",
|
603 |
+
"\n",
|
604 |
+
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
|
605 |
+
]
|
606 |
+
},
|
607 |
+
{
|
608 |
+
"cell_type": "code",
|
609 |
+
"execution_count": 20,
|
610 |
+
"metadata": {
|
611 |
+
"datalore": {
|
612 |
+
"hide_input_from_viewers": true,
|
613 |
+
"hide_output_from_viewers": true,
|
614 |
+
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
|
615 |
+
"type": "CODE"
|
616 |
+
},
|
617 |
+
"gather": {
|
618 |
+
"logged": 1706449761720
|
619 |
+
}
|
620 |
+
},
|
621 |
+
"outputs": [],
|
622 |
+
"source": [
|
623 |
+
"def compute_metrics(eval_pred):\n",
|
624 |
+
" predictions, labels = eval_pred\n",
|
625 |
+
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "code",
|
630 |
+
"execution_count": 21,
|
631 |
+
"metadata": {
|
632 |
+
"datalore": {
|
633 |
+
"hide_input_from_viewers": true,
|
634 |
+
"hide_output_from_viewers": true,
|
635 |
+
"node_id": "1iPZOTKPwSkTgX5dORqT89",
|
636 |
+
"type": "CODE"
|
637 |
+
},
|
638 |
+
"gather": {
|
639 |
+
"logged": 1706449761893
|
640 |
+
}
|
641 |
+
},
|
642 |
+
"outputs": [],
|
643 |
+
"source": [
|
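+
"# Fine-tuning hyperparameters; logging_steps=1 reports every single step,\n",
|
+
"# which is verbose but yields a fine-grained W&B training curve.\n",
|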
644 |
+
"args = TrainingArguments(\n",
|
645 |
+
" output_dir=\"vaers\",\n",
|
646 |
+
" evaluation_strategy=\"epoch\",\n",
|
647 |
+
" learning_rate=2e-5,\n",
|
648 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
649 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
650 |
+
" num_train_epochs=EPOCHS,\n",
|
651 |
+
" weight_decay=.01,\n",
|
652 |
+
" logging_steps=1,\n",
|
653 |
+
" run_name=f\"daedra-training\",\n",
|
654 |
+
" report_to=[\"wandb\"]\n",
|
655 |
+
")"
|
656 |
+
]
|
657 |
+
},
|
658 |
+
{
|
659 |
+
"cell_type": "code",
|
660 |
+
"execution_count": 22,
|
661 |
+
"metadata": {
|
662 |
+
"datalore": {
|
663 |
+
"hide_input_from_viewers": true,
|
664 |
+
"hide_output_from_viewers": true,
|
665 |
+
"node_id": "bnRkNvRYltLun6gCEgL7v0",
|
666 |
+
"type": "CODE"
|
667 |
+
},
|
668 |
+
"gather": {
|
669 |
+
"logged": 1706449769103
|
670 |
+
}
|
671 |
+
},
|
672 |
+
"outputs": [],
|
673 |
+
"source": [
|
674 |
+
"multi_label_trainer = MultiLabelTrainer(\n",
|
675 |
+
" model, \n",
|
676 |
+
" args, \n",
|
677 |
+
" train_dataset=ds_enc[\"train\"], \n",
|
678 |
+
" eval_dataset=ds_enc[\"test\"], \n",
|
679 |
+
" compute_metrics=compute_metrics, \n",
|
680 |
+
" tokenizer=tokenizer\n",
|
681 |
+
")"
|
682 |
+
]
|
683 |
+
},
|
684 |
+
{
|
685 |
+
"cell_type": "code",
|
686 |
+
"execution_count": 23,
|
687 |
+
"metadata": {
|
688 |
+
"datalore": {
|
689 |
+
"hide_input_from_viewers": true,
|
690 |
+
"hide_output_from_viewers": true,
|
691 |
+
"node_id": "LO54PlDkWQdFrzV25FvduB",
|
692 |
+
"type": "CODE"
|
693 |
+
},
|
694 |
+
"gather": {
|
695 |
+
"logged": 1706449880674
|
696 |
+
}
|
697 |
+
},
|
698 |
+
"outputs": [
|
699 |
+
{
|
700 |
+
"name": "stderr",
|
701 |
+
"output_type": "stream",
|
702 |
+
"text": [
|
703 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
|
704 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
705 |
+
]
|
706 |
+
},
|
707 |
+
{
|
708 |
+
"data": {
|
709 |
+
"text/html": [
|
710 |
+
"Tracking run with wandb version 0.16.2"
|
711 |
+
],
|
712 |
+
"text/plain": [
|
713 |
+
"<IPython.core.display.HTML object>"
|
714 |
+
]
|
715 |
+
},
|
716 |
+
"metadata": {},
|
717 |
+
"output_type": "display_data"
|
718 |
+
},
|
719 |
+
{
|
720 |
+
"data": {
|
721 |
+
"text/html": [
|
722 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141956-9lniqjvz</code>"
|
723 |
+
],
|
724 |
+
"text/plain": [
|
725 |
+
"<IPython.core.display.HTML object>"
|
726 |
+
]
|
727 |
+
},
|
728 |
+
"metadata": {},
|
729 |
+
"output_type": "display_data"
|
730 |
+
},
|
731 |
+
{
|
732 |
+
"data": {
|
733 |
+
"text/html": [
|
734 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
735 |
+
],
|
736 |
+
"text/plain": [
|
737 |
+
"<IPython.core.display.HTML object>"
|
738 |
+
]
|
739 |
+
},
|
740 |
+
"metadata": {},
|
741 |
+
"output_type": "display_data"
|
742 |
+
},
|
743 |
+
{
|
744 |
+
"data": {
|
745 |
+
"text/html": [
|
746 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
747 |
+
],
|
748 |
+
"text/plain": [
|
749 |
+
"<IPython.core.display.HTML object>"
|
750 |
+
]
|
751 |
+
},
|
752 |
+
"metadata": {},
|
753 |
+
"output_type": "display_data"
|
754 |
+
},
|
755 |
+
{
|
756 |
+
"data": {
|
757 |
+
"text/html": [
|
758 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a>"
|
759 |
+
],
|
760 |
+
"text/plain": [
|
761 |
+
"<IPython.core.display.HTML object>"
|
762 |
+
]
|
763 |
+
},
|
764 |
+
"metadata": {},
|
765 |
+
"output_type": "display_data"
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"data": {
|
769 |
+
"text/html": [
|
770 |
+
"Finishing last run (ID:9lniqjvz) before initializing another..."
|
771 |
+
],
|
772 |
+
"text/plain": [
|
773 |
+
"<IPython.core.display.HTML object>"
|
774 |
+
]
|
775 |
+
},
|
776 |
+
"metadata": {},
|
777 |
+
"output_type": "display_data"
|
778 |
+
},
|
779 |
+
{
|
780 |
+
"data": {
|
781 |
+
"text/html": [
|
782 |
+
" View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
783 |
+
],
|
784 |
+
"text/plain": [
|
785 |
+
"<IPython.core.display.HTML object>"
|
786 |
+
]
|
787 |
+
},
|
788 |
+
"metadata": {},
|
789 |
+
"output_type": "display_data"
|
790 |
+
},
|
791 |
+
{
|
792 |
+
"data": {
|
793 |
+
"text/html": [
|
794 |
+
"Find logs at: <code>./wandb/run-20240128_141956-9lniqjvz/logs</code>"
|
795 |
+
],
|
796 |
+
"text/plain": [
|
797 |
+
"<IPython.core.display.HTML object>"
|
798 |
+
]
|
799 |
+
},
|
800 |
+
"metadata": {},
|
801 |
+
"output_type": "display_data"
|
802 |
+
},
|
803 |
+
{
|
804 |
+
"data": {
|
805 |
+
"text/html": [
|
806 |
+
"Successfully finished last run (ID:9lniqjvz). Initializing new run:<br/>"
|
807 |
+
],
|
808 |
+
"text/plain": [
|
809 |
+
"<IPython.core.display.HTML object>"
|
810 |
+
]
|
811 |
+
},
|
812 |
+
"metadata": {},
|
813 |
+
"output_type": "display_data"
|
814 |
+
},
|
815 |
+
{
|
816 |
+
"data": {
|
817 |
+
"text/html": [
|
818 |
+
"Tracking run with wandb version 0.16.2"
|
819 |
+
],
|
820 |
+
"text/plain": [
|
821 |
+
"<IPython.core.display.HTML object>"
|
822 |
+
]
|
823 |
+
},
|
824 |
+
"metadata": {},
|
825 |
+
"output_type": "display_data"
|
826 |
+
},
|
827 |
+
{
|
828 |
+
"data": {
|
829 |
+
"text/html": [
|
830 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141958-5idmkcie</code>"
|
831 |
+
],
|
832 |
+
"text/plain": [
|
833 |
+
"<IPython.core.display.HTML object>"
|
834 |
+
]
|
835 |
+
},
|
836 |
+
"metadata": {},
|
837 |
+
"output_type": "display_data"
|
838 |
+
},
|
839 |
+
{
|
840 |
+
"data": {
|
841 |
+
"text/html": [
|
842 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
843 |
+
],
|
844 |
+
"text/plain": [
|
845 |
+
"<IPython.core.display.HTML object>"
|
846 |
+
]
|
847 |
+
},
|
848 |
+
"metadata": {},
|
849 |
+
"output_type": "display_data"
|
850 |
+
},
|
851 |
+
{
|
852 |
+
"data": {
|
853 |
+
"text/html": [
|
854 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
855 |
+
],
|
856 |
+
"text/plain": [
|
857 |
+
"<IPython.core.display.HTML object>"
|
858 |
+
]
|
859 |
+
},
|
860 |
+
"metadata": {},
|
861 |
+
"output_type": "display_data"
|
862 |
+
},
|
863 |
+
{
|
864 |
+
"data": {
|
865 |
+
"text/html": [
|
866 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a>"
|
867 |
+
],
|
868 |
+
"text/plain": [
|
869 |
+
"<IPython.core.display.HTML object>"
|
870 |
+
]
|
871 |
+
},
|
872 |
+
"metadata": {},
|
873 |
+
"output_type": "display_data"
|
874 |
+
},
|
875 |
+
{
|
876 |
+
"data": {
|
877 |
+
"text/html": [
|
878 |
+
"\n",
|
879 |
+
" <div>\n",
|
880 |
+
" \n",
|
881 |
+
" <progress value='1003' max='851' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
882 |
+
" [851/851 26:26]\n",
|
883 |
+
" </div>\n",
|
884 |
+
" "
|
885 |
+
],
|
886 |
+
"text/plain": [
|
887 |
+
"<IPython.core.display.HTML object>"
|
888 |
+
]
|
889 |
+
},
|
890 |
+
"metadata": {},
|
891 |
+
"output_type": "display_data"
|
892 |
+
},
|
893 |
+
{
|
894 |
+
"data": {
|
895 |
+
"text/html": [
|
896 |
+
"<style>\n",
|
897 |
+
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
|
898 |
+
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
|
899 |
+
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
|
900 |
+
" </style>\n",
|
901 |
+
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>▁</td></tr><tr><td>eval/loss</td><td>▁</td></tr><tr><td>eval/runtime</td><td>▁</td></tr><tr><td>eval/samples_per_second</td><td>▁</td></tr><tr><td>eval/steps_per_second</td><td>▁</td></tr><tr><td>train/global_step</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>0.55198</td></tr><tr><td>eval/loss</td><td>0.68442</td></tr><tr><td>eval/runtime</td><td>105.0436</td></tr><tr><td>eval/samples_per_second</td><td>259.168</td></tr><tr><td>eval/steps_per_second</td><td>8.101</td></tr><tr><td>train/global_step</td><td>0</td></tr></table><br/></div></div>"
|
902 |
+
],
|
903 |
+
"text/plain": [
|
904 |
+
"<IPython.core.display.HTML object>"
|
905 |
+
]
|
906 |
+
},
|
907 |
+
"metadata": {},
|
908 |
+
"output_type": "display_data"
|
909 |
+
},
|
910 |
+
{
|
911 |
+
"data": {
|
912 |
+
"text/html": [
|
913 |
+
" View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
914 |
+
],
|
915 |
+
"text/plain": [
|
916 |
+
"<IPython.core.display.HTML object>"
|
917 |
+
]
|
918 |
+
},
|
919 |
+
"metadata": {},
|
920 |
+
"output_type": "display_data"
|
921 |
+
},
|
922 |
+
{
|
923 |
+
"data": {
|
924 |
+
"text/html": [
|
925 |
+
"Find logs at: <code>./wandb/run-20240128_141958-5idmkcie/logs</code>"
|
926 |
+
],
|
927 |
+
"text/plain": [
|
928 |
+
"<IPython.core.display.HTML object>"
|
929 |
+
]
|
930 |
+
},
|
931 |
+
"metadata": {},
|
932 |
+
"output_type": "display_data"
|
933 |
+
}
|
934 |
+
],
|
935 |
+
"source": [
|
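+
"# Tag the W&B run with the subsampling fraction so runs remain comparable.\n",
|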
936 |
+
"if SUBSAMPLING != 1.0:\n",
|
937 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
938 |
+
"else:\n",
|
939 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
940 |
+
" \n",
|
941 |
+
"wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
|
942 |
+
"\n",
|
943 |
+
"multi_label_trainer.evaluate()\n",
|
944 |
+
"wandb.finish()"
|
945 |
+
]
|
946 |
+
},
|
947 |
+
{
|
948 |
+
"cell_type": "code",
|
949 |
+
"execution_count": null,
|
950 |
+
"metadata": {
|
951 |
+
"datalore": {
|
952 |
+
"hide_input_from_viewers": true,
|
953 |
+
"hide_output_from_viewers": true,
|
954 |
+
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
|
955 |
+
"type": "CODE"
|
956 |
+
},
|
957 |
+
"gather": {
|
958 |
+
"logged": 1706449934637
|
959 |
+
}
|
960 |
+
},
|
961 |
+
"outputs": [
|
962 |
+
{
|
963 |
+
"data": {
|
964 |
+
"text/html": [
|
965 |
+
"Tracking run with wandb version 0.16.2"
|
966 |
+
],
|
967 |
+
"text/plain": [
|
968 |
+
"<IPython.core.display.HTML object>"
|
969 |
+
]
|
970 |
+
},
|
971 |
+
"metadata": {},
|
972 |
+
"output_type": "display_data"
|
973 |
+
},
|
974 |
+
{
|
975 |
+
"data": {
|
976 |
+
"text/html": [
|
977 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_142151-2mcc0ibc</code>"
|
978 |
+
],
|
979 |
+
"text/plain": [
|
980 |
+
"<IPython.core.display.HTML object>"
|
981 |
+
]
|
982 |
+
},
|
983 |
+
"metadata": {},
|
984 |
+
"output_type": "display_data"
|
985 |
+
},
|
986 |
+
{
|
987 |
+
"data": {
|
988 |
+
"text/html": [
|
989 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
990 |
+
],
|
991 |
+
"text/plain": [
|
992 |
+
"<IPython.core.display.HTML object>"
|
993 |
+
]
|
994 |
+
},
|
995 |
+
"metadata": {},
|
996 |
+
"output_type": "display_data"
|
997 |
+
},
|
998 |
+
{
|
999 |
+
"data": {
|
1000 |
+
"text/html": [
|
1001 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
1002 |
+
],
|
1003 |
+
"text/plain": [
|
1004 |
+
"<IPython.core.display.HTML object>"
|
1005 |
+
]
|
1006 |
+
},
|
1007 |
+
"metadata": {},
|
1008 |
+
"output_type": "display_data"
|
1009 |
+
},
|
1010 |
+
{
|
1011 |
+
"data": {
|
1012 |
+
"text/html": [
|
1013 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc</a>"
|
1014 |
+
],
|
1015 |
+
"text/plain": [
|
1016 |
+
"<IPython.core.display.HTML object>"
|
1017 |
+
]
|
1018 |
+
},
|
1019 |
+
"metadata": {},
|
1020 |
+
"output_type": "display_data"
|
1021 |
+
},
|
1022 |
+
{
|
1023 |
+
"data": {
|
1024 |
+
"text/html": [
|
1025 |
+
"\n",
|
1026 |
+
" <div>\n",
|
1027 |
+
" \n",
|
1028 |
+
" <progress value='3972' max='11913' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
1029 |
+
" [ 3972/11913 24:20 < 48:40, 2.72 it/s, Epoch 1/3]\n",
|
1030 |
+
" </div>\n",
|
1031 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
1032 |
+
" <thead>\n",
|
1033 |
+
" <tr style=\"text-align: left;\">\n",
|
1034 |
+
" <th>Epoch</th>\n",
|
1035 |
+
" <th>Training Loss</th>\n",
|
1036 |
+
" <th>Validation Loss</th>\n",
|
1037 |
+
" </tr>\n",
|
1038 |
+
" </thead>\n",
|
1039 |
+
" <tbody>\n",
|
1040 |
+
" </tbody>\n",
|
1041 |
+
"</table><p>"
|
1042 |
+
],
|
1043 |
+
"text/plain": [
|
1044 |
+
"<IPython.core.display.HTML object>"
|
1045 |
+
]
|
1046 |
+
},
|
1047 |
+
"metadata": {},
|
1048 |
+
"output_type": "display_data"
|
1049 |
+
},
|
1050 |
+
{
|
1051 |
+
"name": "stderr",
|
1052 |
+
"output_type": "stream",
|
1053 |
+
"text": [
|
1054 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.6s\n",
|
1055 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 22.7s\n",
|
1056 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 14.0s\n",
|
1057 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 15.2s\n",
|
1058 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.0s\n",
|
1059 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 12.4s\n",
|
1060 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 13.4s\n"
|
1061 |
+
]
|
1062 |
+
}
|
1063 |
+
],
|
1064 |
+
"source": [
|
1065 |
+
"if SUBSAMPLING != 1.0:\n",
|
1066 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
1067 |
+
"else:\n",
|
1068 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
1069 |
+
" \n",
|
1070 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
|
1071 |
+
"\n",
|
1072 |
+
"multi_label_trainer.train()\n",
|
1073 |
+
"wandb.finish()"
|
1074 |
+
]
|
1075 |
+
},
|
1076 |
+
{
|
1077 |
+
"cell_type": "markdown",
|
1078 |
+
"metadata": {},
|
1079 |
+
"source": [
|
1080 |
+
"### Evaluation"
|
1081 |
+
]
|
1082 |
+
},
|
1083 |
+
{
|
1084 |
+
"cell_type": "markdown",
|
1085 |
+
"metadata": {},
|
1086 |
+
"source": [
|
1087 |
+
"We instantiate a classifier `pipeline` and push it to CUDA."
|
1088 |
+
]
|
1089 |
+
},
|
1090 |
+
{
|
1091 |
+
"cell_type": "code",
|
1092 |
+
"execution_count": null,
|
1093 |
+
"metadata": {
|
1094 |
+
"datalore": {
|
1095 |
+
"hide_input_from_viewers": true,
|
1096 |
+
"hide_output_from_viewers": true,
|
1097 |
+
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
|
1098 |
+
"type": "CODE"
|
1099 |
+
},
|
1100 |
+
"gather": {
|
1101 |
+
"logged": 1706411459928
|
1102 |
+
}
|
1103 |
+
},
|
1104 |
+
"outputs": [],
|
1105 |
+
"source": [
|
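+
"# Wrap the fine-tuned model in a text-classification pipeline on the first GPU.\n",
|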
1106 |
+
"classifier = pipeline(\"text-classification\", \n",
|
1107 |
+
" model, \n",
|
1108 |
+
" tokenizer=tokenizer, \n",
|
1109 |
+
" device=\"cuda:0\")"
|
1110 |
+
]
|
1111 |
+
},
|
1112 |
+
{
|
1113 |
+
"cell_type": "markdown",
|
1114 |
+
"metadata": {},
|
1115 |
+
"source": [
|
1116 |
+
"We use the same tokenizer used for training to tokenize/encode the validation set."
|
1117 |
+
]
|
1118 |
+
},
|
1119 |
+
{
|
1120 |
+
"cell_type": "code",
|
1121 |
+
"execution_count": null,
|
1122 |
+
"metadata": {
|
1123 |
+
"datalore": {
|
1124 |
+
"hide_input_from_viewers": true,
|
1125 |
+
"hide_output_from_viewers": true,
|
1126 |
+
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
|
1127 |
+
"type": "CODE"
|
1128 |
+
},
|
1129 |
+
"gather": {
|
1130 |
+
"logged": 1706411523285
|
1131 |
+
}
|
1132 |
+
},
|
1133 |
+
"outputs": [],
|
1134 |
+
"source": [
|
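+
"# With max_length=None, padding='max_length' pads every sequence to the\n",
|
+
"# model's maximum input length.\n",
|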
1135 |
+
"test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
|
1136 |
+
" max_length=None, \n",
|
1137 |
+
" padding='max_length', \n",
|
1138 |
+
" return_token_type_ids=True, \n",
|
1139 |
+
" truncation=True)"
|
1140 |
+
]
|
1141 |
+
},
|
1142 |
+
{
|
1143 |
+
"cell_type": "markdown",
|
1144 |
+
"metadata": {},
|
1145 |
+
"source": [
|
1146 |
+
"Once we've made the data loadable by putting it into a `DataLoader`, we "
|
1147 |
+
]
|
1148 |
+
},
|
1149 |
+
{
|
1150 |
+
"cell_type": "code",
|
1151 |
+
"execution_count": null,
|
1152 |
+
"metadata": {
|
1153 |
+
"datalore": {
|
1154 |
+
"hide_input_from_viewers": true,
|
1155 |
+
"hide_output_from_viewers": true,
|
1156 |
+
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
|
1157 |
+
"type": "CODE"
|
1158 |
+
},
|
1159 |
+
"gather": {
|
1160 |
+
"logged": 1706411543379
|
1161 |
+
}
|
1162 |
+
},
|
1163 |
+
"outputs": [],
|
1164 |
+
"source": [
|
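+
"# Package the encodings and labels as tensors; SequentialSampler preserves\n",
|
+
"# the original order, so predictions can be matched back to their texts.\n",
|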
1165 |
+
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
|
1166 |
+
" torch.tensor(test_encodings['attention_mask']), \n",
|
1167 |
+
" torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
|
1168 |
+
" torch.tensor(test_encodings['token_type_ids']))\n",
|
1169 |
+
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
|
1170 |
+
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
|
1171 |
+
" batch_size=BATCH_SIZE)"
|
1172 |
+
]
|
1173 |
+
},
|
1174 |
+
{
|
1175 |
+
"cell_type": "code",
|
1176 |
+
"execution_count": null,
|
1177 |
+
"metadata": {
|
1178 |
+
"datalore": {
|
1179 |
+
"hide_input_from_viewers": true,
|
1180 |
+
"hide_output_from_viewers": true,
|
1181 |
+
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
|
1182 |
+
"type": "CODE"
|
1183 |
+
},
|
1184 |
+
"gather": {
|
1185 |
+
"logged": 1706411587843
|
1186 |
+
}
|
1187 |
+
},
|
1188 |
+
"outputs": [],
|
1189 |
+
"source": [
|
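+
"# Batched inference over the validation DataLoader: collect raw logits,\n",
|
+
"# sigmoid probabilities and gold labels for the classification report below.\n",
|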
1190 |
+
"model.eval()\n",
|
1191 |
+
"\n",
|
1192 |
+
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
|
1193 |
+
"\n",
|
1194 |
+
"for i, batch in enumerate(test_dataloader):\n",
|
1195 |
+
" batch = tuple(t.to(device) for t in batch)\n",
|
1196 |
+
" \n",
|
1197 |
+
" # Unpack the inputs from our dataloader\n",
|
1198 |
+
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
|
1199 |
+
" \n",
|
1200 |
+
" with torch.no_grad():\n",
|
1201 |
+
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
|
1202 |
+
" b_logit_pred = outs[0]\n",
|
1203 |
+
" pred_label = torch.sigmoid(b_logit_pred)\n",
|
1204 |
+
"\n",
|
1205 |
+
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
|
1206 |
+
" pred_label = pred_label.to('cpu').numpy()\n",
|
1207 |
+
" b_labels = b_labels.to('cpu').numpy()\n",
|
1208 |
+
"\n",
|
1209 |
+
" tokenized_texts.append(b_input_ids)\n",
|
1210 |
+
" logit_preds.append(b_logit_pred)\n",
|
1211 |
+
" true_labels.append(b_labels)\n",
|
1212 |
+
" pred_labels.append(pred_label)\n",
|
1213 |
+
"\n",
|
1214 |
+
"# Flatten outputs\n",
|
1215 |
+
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
|
1216 |
+
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
|
1217 |
+
"true_labels = [item for sublist in true_labels for item in sublist]\n",
|
1218 |
+
"\n",
|
1219 |
+
"# Converting flattened binary values to boolean values\n",
|
1220 |
+
"true_bools = [tl == 1 for tl in true_labels]\n",
|
1221 |
+
"pred_bools = [pl > 0.50 for pl in pred_labels] "
|
1222 |
+
]
|
1223 |
+
},
|
1224 |
+
{
|
1225 |
+
"cell_type": "markdown",
|
1226 |
+
"metadata": {},
|
1227 |
+
"source": [
|
1228 |
+
"We create a classification report:"
|
1229 |
+
]
|
1230 |
+
},
|
1231 |
+
{
|
1232 |
+
"cell_type": "code",
|
1233 |
+
"execution_count": null,
|
1234 |
+
"metadata": {
|
1235 |
+
"datalore": {
|
1236 |
+
"hide_input_from_viewers": true,
|
1237 |
+
"hide_output_from_viewers": true,
|
1238 |
+
"node_id": "eBprrgF086mznPbPVBpOLS",
|
1239 |
+
"type": "CODE"
|
1240 |
+
},
|
1241 |
+
"gather": {
|
1242 |
+
"logged": 1706411588249
|
1243 |
+
}
|
1244 |
+
},
|
1245 |
+
"outputs": [],
|
1246 |
+
"source": [
|
1247 |
+
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
|
1248 |
+
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
|
1249 |
+
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
|
1250 |
+
"print(clf_report)"
|
1251 |
+
]
|
1252 |
+
},
|
1253 |
+
{
|
1254 |
+
"cell_type": "markdown",
|
1255 |
+
"metadata": {},
|
1256 |
+
"source": [
|
1257 |
+
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
|
1258 |
+
]
|
1259 |
+
},
|
1260 |
+
{
|
1261 |
+
"cell_type": "code",
|
1262 |
+
"execution_count": null,
|
1263 |
+
"metadata": {
|
1264 |
+
"datalore": {
|
1265 |
+
"hide_input_from_viewers": true,
|
1266 |
+
"hide_output_from_viewers": true,
|
1267 |
+
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
|
1268 |
+
"type": "CODE"
|
1269 |
+
},
|
1270 |
+
"gather": {
|
1271 |
+
"logged": 1706411588638
|
1272 |
+
}
|
1273 |
+
},
|
1274 |
+
"outputs": [],
|
1275 |
+
"source": [
|
1276 |
+
"# Creating a map of class names from class numbers\n",
|
1277 |
+
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
|
1278 |
+
]
|
1279 |
+
},
|
1280 |
+
{
|
1281 |
+
"cell_type": "code",
|
1282 |
+
"execution_count": null,
|
1283 |
+
"metadata": {
|
1284 |
+
"datalore": {
|
1285 |
+
"hide_input_from_viewers": true,
|
1286 |
+
"hide_output_from_viewers": true,
|
1287 |
+
"node_id": "jH0S35dDteUch01sa6me6e",
|
1288 |
+
"type": "CODE"
|
1289 |
+
},
|
1290 |
+
"gather": {
|
1291 |
+
"logged": 1706411589004
|
1292 |
+
}
|
1293 |
+
},
|
1294 |
+
"outputs": [],
|
1295 |
+
"source": [
|
1296 |
+
"true_label_idxs, pred_label_idxs = [], []\n",
|
1297 |
+
"\n",
|
1298 |
+
"for vals in true_bools:\n",
|
1299 |
+
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
|
1300 |
+
"for vals in pred_bools:\n",
|
1301 |
+
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
|
1302 |
+
]
|
1303 |
+
},
|
1304 |
+
{
|
1305 |
+
"cell_type": "code",
|
1306 |
+
"execution_count": null,
|
1307 |
+
"metadata": {
|
1308 |
+
"datalore": {
|
1309 |
+
"hide_input_from_viewers": true,
|
1310 |
+
"hide_output_from_viewers": true,
|
1311 |
+
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
|
1312 |
+
"type": "CODE"
|
1313 |
+
},
|
1314 |
+
"gather": {
|
1315 |
+
"logged": 1706411589301
|
1316 |
+
}
|
1317 |
+
},
|
1318 |
+
"outputs": [],
|
1319 |
+
"source": [
|
1320 |
+
"true_label_texts, pred_label_texts = [], []\n",
|
1321 |
+
"\n",
|
1322 |
+
"for vals in true_label_idxs:\n",
|
1323 |
+
" if vals:\n",
|
1324 |
+
" true_label_texts.append([idx2label[val] for val in vals])\n",
|
1325 |
+
" else:\n",
|
1326 |
+
" true_label_texts.append(vals)\n",
|
1327 |
+
"\n",
|
1328 |
+
"for vals in pred_label_idxs:\n",
|
1329 |
+
" if vals:\n",
|
1330 |
+
" pred_label_texts.append([idx2label[val] for val in vals])\n",
|
1331 |
+
" else:\n",
|
1332 |
+
" pred_label_texts.append(vals)"
|
1333 |
+
]
|
1334 |
+
},
|
1335 |
+
{
|
1336 |
+
"cell_type": "code",
|
1337 |
+
"execution_count": null,
|
1338 |
+
"metadata": {
|
1339 |
+
"datalore": {
|
1340 |
+
"hide_input_from_viewers": true,
|
1341 |
+
"hide_output_from_viewers": true,
|
1342 |
+
"node_id": "SxUmVHfQISEeptg1SawOmB",
|
1343 |
+
"type": "CODE"
|
1344 |
+
},
|
1345 |
+
"gather": {
|
1346 |
+
"logged": 1706411591952
|
1347 |
+
}
|
1348 |
+
},
|
1349 |
+
"outputs": [],
|
1350 |
+
"source": [
|
1351 |
+
"symptom_texts = [tokenizer.decode(text,\n",
|
1352 |
+
" skip_special_tokens=True,\n",
|
1353 |
+
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
|
1354 |
+
]
|
1355 |
+
},
|
1356 |
+
{
|
1357 |
+
"cell_type": "code",
|
1358 |
+
"execution_count": null,
|
1359 |
+
"metadata": {
|
1360 |
+
"datalore": {
|
1361 |
+
"hide_input_from_viewers": true,
|
1362 |
+
"hide_output_from_viewers": true,
|
1363 |
+
"node_id": "BxFNigNGRLTOqraI55BPSH",
|
1364 |
+
"type": "CODE"
|
1365 |
+
},
|
1366 |
+
"gather": {
|
1367 |
+
"logged": 1706411592512
|
1368 |
+
}
|
1369 |
+
},
|
1370 |
+
"outputs": [],
|
1371 |
+
"source": [
|
1372 |
+
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
|
1373 |
+
" 'true_labels': true_label_texts, \n",
|
1374 |
+
" 'pred_labels':pred_label_texts})\n",
|
1375 |
+
"comparisons_df.to_csv('comparisons.csv')\n",
|
1376 |
+
"comparisons_df"
|
1377 |
+
]
|
1378 |
+
}
|
1379 |
+
],
|
1380 |
+
"metadata": {
|
1381 |
+
"datalore": {
|
1382 |
+
"base_environment": "default",
|
1383 |
+
"computation_mode": "JUPYTER",
|
1384 |
+
"package_manager": "pip",
|
1385 |
+
"packages": [
|
1386 |
+
{
|
1387 |
+
"name": "datasets",
|
1388 |
+
"source": "PIP",
|
1389 |
+
"version": "2.16.1"
|
1390 |
+
},
|
1391 |
+
{
|
1392 |
+
"name": "torch",
|
1393 |
+
"source": "PIP",
|
1394 |
+
"version": "2.1.2"
|
1395 |
+
},
|
1396 |
+
{
|
1397 |
+
"name": "accelerate",
|
1398 |
+
"source": "PIP",
|
1399 |
+
"version": "0.26.1"
|
1400 |
+
}
|
1401 |
+
],
|
1402 |
+
"report_row_ids": [
|
1403 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
1404 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
1405 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
1406 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
1407 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
1408 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
1409 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
1410 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
1411 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
1412 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
1413 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
1414 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
1415 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
1416 |
+
],
|
1417 |
+
"version": 3
|
1418 |
+
},
|
1419 |
+
"kernelspec": {
|
1420 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow",
|
1421 |
+
"language": "python",
|
1422 |
+
"name": "python38-azureml-pt-tf"
|
1423 |
+
},
|
1424 |
+
"language_info": {
|
1425 |
+
"codemirror_mode": {
|
1426 |
+
"name": "ipython",
|
1427 |
+
"version": 3
|
1428 |
+
},
|
1429 |
+
"file_extension": ".py",
|
1430 |
+
"mimetype": "text/x-python",
|
1431 |
+
"name": "python",
|
1432 |
+
"nbconvert_exporter": "python",
|
1433 |
+
"pygments_lexer": "ipython3",
|
1434 |
+
"version": "3.8.5"
|
1435 |
+
},
|
1436 |
+
"microsoft": {
|
1437 |
+
"host": {
|
1438 |
+
"AzureML": {
|
1439 |
+
"notebookHasBeenCompleted": true
|
1440 |
+
}
|
1441 |
+
},
|
1442 |
+
"ms_spell_check": {
|
1443 |
+
"ms_spell_check_language": "en"
|
1444 |
+
}
|
1445 |
+
},
|
1446 |
+
"nteract": {
|
1447 |
+
"version": "nteract-front-end@1.0.0"
|
1448 |
+
}
|
1449 |
+
},
|
1450 |
+
"nbformat": 4,
|
1451 |
+
"nbformat_minor": 4
|
1452 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-16-26-9Z.ipynb
ADDED
@@ -0,0 +1,1246 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {}
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"source": [
|
15 |
+
"# %pip install accelerate -U"
|
16 |
+
],
|
17 |
+
"outputs": [],
|
18 |
+
"execution_count": 1,
|
19 |
+
"metadata": {
|
20 |
+
"nteract": {
|
21 |
+
"transient": {
|
22 |
+
"deleting": false
|
23 |
+
}
|
24 |
+
},
|
25 |
+
"tags": []
|
26 |
+
}
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"source": [
|
31 |
+
"%pip install transformers datasets shap watermark wandb scikit-multilearn"
|
32 |
+
],
|
33 |
+
"outputs": [
|
34 |
+
{
|
35 |
+
"output_type": "stream",
|
36 |
+
"name": "stdout",
|
37 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nCollecting scikit-multilearn\n Downloading scikit_multilearn-0.2.0-py3-none-any.whl (89 kB)\n\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m89.4/89.4 kB\u001b[0m \u001b[31m9.8 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n\u001b[?25hRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: numba in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: stack-data in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: scikit-multilearn\nSuccessfully installed scikit-multilearn-0.2.0\nNote: you may need to restart the kernel to use updated packages.\n"
|
38 |
+
}
|
39 |
+
],
|
40 |
+
"execution_count": 1,
|
41 |
+
"metadata": {
|
42 |
+
"nteract": {
|
43 |
+
"transient": {
|
44 |
+
"deleting": false
|
45 |
+
}
|
46 |
+
}
|
47 |
+
}
|
48 |
+
},
|
49 |
+
{
|
50 |
+
"cell_type": "code",
|
51 |
+
"source": [
|
52 |
+
"import pandas as pd\n",
|
53 |
+
"import numpy as np\n",
|
54 |
+
"import torch\n",
|
55 |
+
"import os\n",
|
56 |
+
"from typing import List\n",
|
57 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
58 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
|
59 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
60 |
+
"from pyarrow import Table\n",
|
61 |
+
"import shap\n",
|
62 |
+
"import wandb\n",
|
63 |
+
"from skmultilearn.problem_transform import LabelPowerset\n",
|
64 |
+
"\n",
|
65 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
66 |
+
"\n",
|
67 |
+
"%load_ext watermark"
|
68 |
+
],
|
69 |
+
"outputs": [
|
70 |
+
{
|
71 |
+
"output_type": "stream",
|
72 |
+
"name": "stderr",
|
73 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 15:09:42.856486: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 15:09:43.818179: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 15:09:43.818307: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 15:09:43.818321: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
74 |
+
}
|
75 |
+
],
|
76 |
+
"execution_count": 2,
|
77 |
+
"metadata": {
|
78 |
+
"datalore": {
|
79 |
+
"hide_input_from_viewers": false,
|
80 |
+
"hide_output_from_viewers": false,
|
81 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
82 |
+
"report_properties": {
|
83 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
84 |
+
},
|
85 |
+
"type": "CODE"
|
86 |
+
},
|
87 |
+
"gather": {
|
88 |
+
"logged": 1706454586481
|
89 |
+
},
|
90 |
+
"tags": []
|
91 |
+
}
|
92 |
+
},
|
93 |
+
{
|
94 |
+
"cell_type": "code",
|
95 |
+
"source": [
|
96 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
97 |
+
"\n",
|
98 |
+
"SEED: int = 42\n",
|
99 |
+
"\n",
|
100 |
+
"BATCH_SIZE: int = 16\n",
|
101 |
+
"EPOCHS: int = 3\n",
|
102 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
103 |
+
"\n",
|
104 |
+
"CLASS_NAMES: List[str] = [\"DIED\",\n",
|
105 |
+
" \"ER_VISIT\",\n",
|
106 |
+
" \"HOSPITAL\",\n",
|
107 |
+
" \"OFC_VISIT\",\n",
|
108 |
+
" #\"X_STAY\", # pruned\n",
|
109 |
+
" #\"DISABLE\", # pruned\n",
|
110 |
+
" #\"D_PRESENTED\" # pruned\n",
|
111 |
+
" ]\n",
|
112 |
+
"\n",
|
113 |
+
"\n",
|
114 |
+
"\n",
|
115 |
+
"\n",
|
116 |
+
"# WandB configuration\n",
|
117 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
|
118 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
119 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
120 |
+
],
|
121 |
+
"outputs": [],
|
122 |
+
"execution_count": 3,
|
123 |
+
"metadata": {
|
124 |
+
"collapsed": false,
|
125 |
+
"gather": {
|
126 |
+
"logged": 1706454586654
|
127 |
+
},
|
128 |
+
"jupyter": {
|
129 |
+
"outputs_hidden": false
|
130 |
+
}
|
131 |
+
}
|
132 |
+
},
|
133 |
+
{
|
134 |
+
"cell_type": "code",
|
135 |
+
"source": [
|
136 |
+
"%watermark --iversion"
|
137 |
+
],
|
138 |
+
"outputs": [
|
139 |
+
{
|
140 |
+
"output_type": "stream",
|
141 |
+
"name": "stdout",
|
142 |
+
"text": "shap : 0.44.1\nlogging: 0.5.1.2\npandas : 2.0.2\nnumpy : 1.23.5\ntorch : 1.12.0\nwandb : 0.16.2\nre : 2.2.1\n\n"
|
143 |
+
}
|
144 |
+
],
|
145 |
+
"execution_count": 4,
|
146 |
+
"metadata": {
|
147 |
+
"collapsed": false,
|
148 |
+
"jupyter": {
|
149 |
+
"outputs_hidden": false
|
150 |
+
}
|
151 |
+
}
|
152 |
+
},
|
153 |
+
{
|
154 |
+
"cell_type": "code",
|
155 |
+
"source": [
|
156 |
+
"!nvidia-smi"
|
157 |
+
],
|
158 |
+
"outputs": [
|
159 |
+
{
|
160 |
+
"output_type": "stream",
|
161 |
+
"name": "stdout",
|
162 |
+
"text": "Sun Jan 28 15:09:47 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 30C P0 38W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 29C P0 38W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
163 |
+
}
|
164 |
+
],
|
165 |
+
"execution_count": 5,
|
166 |
+
"metadata": {
|
167 |
+
"datalore": {
|
168 |
+
"hide_input_from_viewers": true,
|
169 |
+
"hide_output_from_viewers": true,
|
170 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
171 |
+
"type": "CODE"
|
172 |
+
}
|
173 |
+
}
|
174 |
+
},
|
175 |
+
{
|
176 |
+
"cell_type": "markdown",
|
177 |
+
"source": [
|
178 |
+
"## Loading the data set"
|
179 |
+
],
|
180 |
+
"metadata": {
|
181 |
+
"datalore": {
|
182 |
+
"hide_input_from_viewers": false,
|
183 |
+
"hide_output_from_viewers": false,
|
184 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
185 |
+
"report_properties": {
|
186 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
187 |
+
},
|
188 |
+
"type": "MD"
|
189 |
+
}
|
190 |
+
}
|
191 |
+
},
|
192 |
+
{
|
193 |
+
"cell_type": "code",
|
194 |
+
"source": [
|
195 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
196 |
+
],
|
197 |
+
"outputs": [],
|
198 |
+
"execution_count": 7,
|
199 |
+
"metadata": {
|
200 |
+
"collapsed": false,
|
201 |
+
"gather": {
|
202 |
+
"logged": 1706449040507
|
203 |
+
},
|
204 |
+
"jupyter": {
|
205 |
+
"outputs_hidden": false
|
206 |
+
}
|
207 |
+
}
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"cell_type": "code",
|
211 |
+
"source": [
|
212 |
+
"dataset"
|
213 |
+
],
|
214 |
+
"outputs": [
|
215 |
+
{
|
216 |
+
"output_type": "execute_result",
|
217 |
+
"execution_count": 8,
|
218 |
+
"data": {
|
219 |
+
"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 272238\n })\n})"
|
220 |
+
},
|
221 |
+
"metadata": {}
|
222 |
+
}
|
223 |
+
],
|
224 |
+
"execution_count": 8,
|
225 |
+
"metadata": {
|
226 |
+
"collapsed": false,
|
227 |
+
"gather": {
|
228 |
+
"logged": 1706449044205
|
229 |
+
},
|
230 |
+
"jupyter": {
|
231 |
+
"outputs_hidden": false,
|
232 |
+
"source_hidden": false
|
233 |
+
},
|
234 |
+
"nteract": {
|
235 |
+
"transient": {
|
236 |
+
"deleting": false
|
237 |
+
}
|
238 |
+
}
|
239 |
+
}
|
240 |
+
},
|
241 |
+
{
|
242 |
+
"cell_type": "code",
|
243 |
+
"source": [
|
244 |
+
"SUBSAMPLING: float = 0.1"
|
245 |
+
],
|
246 |
+
"outputs": [],
|
247 |
+
"execution_count": 9,
|
248 |
+
"metadata": {}
|
249 |
+
},
|
250 |
+
{
|
251 |
+
"cell_type": "code",
|
252 |
+
"source": [
|
253 |
+
"def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
|
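+
"    # Shuffle each split, then keep the first `fraction` of rows; the\n",
|
+
"    # subsample is random but not stratified by label.\n",
|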
254 |
+
" res = DatasetDict()\n",
|
255 |
+
"\n",
|
256 |
+
" res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
|
257 |
+
" res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
|
258 |
+
" res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
|
259 |
+
" \n",
|
260 |
+
" return res"
|
261 |
+
],
|
262 |
+
"outputs": [],
|
263 |
+
"execution_count": 10,
|
264 |
+
"metadata": {
|
265 |
+
"collapsed": false,
|
266 |
+
"gather": {
|
267 |
+
"logged": 1706449378281
|
268 |
+
},
|
269 |
+
"jupyter": {
|
270 |
+
"outputs_hidden": false,
|
271 |
+
"source_hidden": false
|
272 |
+
},
|
273 |
+
"nteract": {
|
274 |
+
"transient": {
|
275 |
+
"deleting": false
|
276 |
+
}
|
277 |
+
}
|
278 |
+
}
|
279 |
+
},
|
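For reproducibility it is worth noting that `shuffle()` above is unseeded, so each run subsamples differently. A minimal sketch of a seeded equivalent (the `seed` parameter and the helper name are additions for illustration, not part of the committed cell):

    from datasets import Dataset, DatasetDict

    def minisample_seeded(ds: DatasetDict, fraction: float, seed: int = 42) -> DatasetDict:
        # Shuffle each split deterministically, then keep the first `fraction` of it.
        res = DatasetDict()
        for split in ("train", "test", "val"):
            n = round(len(ds[split]) * fraction)
            res[split] = Dataset.from_dict(ds[split].shuffle(seed=seed)[:n])
        return res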
280 |
+
{
|
281 |
+
"cell_type": "code",
|
282 |
+
"source": [
|
283 |
+
"dataset = minisample(dataset, SUBSAMPLING)"
|
284 |
+
],
|
285 |
+
"outputs": [],
|
286 |
+
"execution_count": 11,
|
287 |
+
"metadata": {
|
288 |
+
"collapsed": false,
|
289 |
+
"gather": {
|
290 |
+
"logged": 1706449384162
|
291 |
+
},
|
292 |
+
"jupyter": {
|
293 |
+
"outputs_hidden": false,
|
294 |
+
"source_hidden": false
|
295 |
+
},
|
296 |
+
"nteract": {
|
297 |
+
"transient": {
|
298 |
+
"deleting": false
|
299 |
+
}
|
300 |
+
}
|
301 |
+
}
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"source": [
|
306 |
+
"dataset"
|
307 |
+
],
|
308 |
+
"outputs": [
|
309 |
+
{
|
310 |
+
"output_type": "execute_result",
|
311 |
+
"execution_count": 12,
|
312 |
+
"data": {
|
313 |
+
"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 127044\n })\n test: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 27224\n })\n val: Dataset({\n features: ['id', 'text', 'labels'],\n num_rows: 27224\n })\n})"
|
314 |
+
},
|
315 |
+
"metadata": {}
|
316 |
+
}
|
317 |
+
],
|
318 |
+
"execution_count": 12,
|
319 |
+
"metadata": {
|
320 |
+
"collapsed": false,
|
321 |
+
"gather": {
|
322 |
+
"logged": 1706449387981
|
323 |
+
},
|
324 |
+
"jupyter": {
|
325 |
+
"outputs_hidden": false,
|
326 |
+
"source_hidden": false
|
327 |
+
},
|
328 |
+
"nteract": {
|
329 |
+
"transient": {
|
330 |
+
"deleting": false
|
331 |
+
}
|
332 |
+
}
|
333 |
+
}
|
334 |
+
},
|
335 |
+
{
|
336 |
+
"cell_type": "markdown",
|
337 |
+
"source": [
|
338 |
+
"We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
|
339 |
+
],
|
340 |
+
"metadata": {
|
341 |
+
"nteract": {
|
342 |
+
"transient": {
|
343 |
+
"deleting": false
|
344 |
+
}
|
345 |
+
}
|
346 |
+
}
|
347 |
+
},
|
348 |
+
{
|
349 |
+
"cell_type": "code",
|
350 |
+
"source": [
|
351 |
+
"ds = DatasetDict()\n",
|
352 |
+
"\n",
|
353 |
+
"for i in [\"test\", \"train\", \"val\"]:\n",
|
354 |
+
" tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
|
355 |
+
" ds[i] = Dataset(tab)\n",
|
356 |
+
"\n",
|
357 |
+
"dataset = ds"
|
358 |
+
],
|
359 |
+
"outputs": [],
|
360 |
+
"execution_count": 13,
|
361 |
+
"metadata": {
|
362 |
+
"collapsed": false,
|
363 |
+
"gather": {
|
364 |
+
"logged": 1706449443055
|
365 |
+
},
|
366 |
+
"jupyter": {
|
367 |
+
"outputs_hidden": false,
|
368 |
+
"source_hidden": false
|
369 |
+
},
|
370 |
+
"nteract": {
|
371 |
+
"transient": {
|
372 |
+
"deleting": false
|
373 |
+
}
|
374 |
+
}
|
375 |
+
}
|
376 |
+
},
|
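To make the pruning concrete: each `labels` entry is a multi-hot vector, and `label[:4]` keeps only the four leading positions. A toy illustration (the six-outcome vector below is invented for the example):

    # Hypothetical ordering: [DIED, ER_VISIT, HOSPITAL, OFC_VISIT, X_STAY, DISABLE]
    full_label = [0, 1, 1, 0, 0, 1]
    pruned = full_label[:4]  # -> [0, 1, 1, 0]; the trailing outcomes are dropped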
377 |
+
{
|
378 |
+
"cell_type": "markdown",
|
379 |
+
"source": [
|
380 |
+
"### Tokenisation and encoding"
|
381 |
+
],
|
382 |
+
"metadata": {}
|
383 |
+
},
|
384 |
+
{
|
385 |
+
"cell_type": "code",
|
386 |
+
"source": [
|
387 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
|
388 |
+
],
|
389 |
+
"outputs": [],
|
390 |
+
"execution_count": 14,
|
391 |
+
"metadata": {
|
392 |
+
"datalore": {
|
393 |
+
"hide_input_from_viewers": true,
|
394 |
+
"hide_output_from_viewers": true,
|
395 |
+
"node_id": "I7n646PIscsUZRoHu6m7zm",
|
396 |
+
"type": "CODE"
|
397 |
+
},
|
398 |
+
"gather": {
|
399 |
+
"logged": 1706449638377
|
400 |
+
}
|
401 |
+
}
|
402 |
+
},
|
403 |
+
{
|
404 |
+
"cell_type": "code",
|
405 |
+
"source": [
|
406 |
+
"def tokenize_and_encode(examples):\n",
|
407 |
+
" return tokenizer(examples[\"text\"], truncation=True)"
|
408 |
+
],
|
409 |
+
"outputs": [],
|
410 |
+
"execution_count": 15,
|
411 |
+
"metadata": {
|
412 |
+
"datalore": {
|
413 |
+
"hide_input_from_viewers": true,
|
414 |
+
"hide_output_from_viewers": true,
|
415 |
+
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
|
416 |
+
"type": "CODE"
|
417 |
+
},
|
418 |
+
"gather": {
|
419 |
+
"logged": 1706449642580
|
420 |
+
}
|
421 |
+
}
|
422 |
+
},
|
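The mapping function simply defers to the fast tokenizer, so it returns the usual encoded fields. A quick sanity check on a single invented report (the text and printed values are illustrative):

    sample = tokenize_and_encode({"text": ["Patient developed a fever after vaccination."]})
    print(sample.keys())                 # e.g. dict_keys(['input_ids', 'attention_mask'])
    print(len(sample["input_ids"][0]))   # token count, capped at the model's max length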
423 |
+
{
|
424 |
+
"cell_type": "code",
|
425 |
+
"source": [
|
426 |
+
"cols = dataset[\"train\"].column_names\n",
|
427 |
+
"cols.remove(\"labels\")\n",
|
428 |
+
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
|
429 |
+
],
|
430 |
+
"outputs": [
|
431 |
+
{
|
432 |
+
"output_type": "stream",
|
433 |
+
"name": "stderr",
|
434 |
+
"text": "Map: 100%|██████████| 27224/27224 [00:10<00:00, 2638.52 examples/s]\nMap: 100%|██████████| 127044/127044 [00:48<00:00, 2633.40 examples/s]\nMap: 100%|██████████| 27224/27224 [00:10<00:00, 2613.19 examples/s]\n"
|
435 |
+
}
|
436 |
+
],
|
437 |
+
"execution_count": 16,
|
438 |
+
"metadata": {
|
439 |
+
"datalore": {
|
440 |
+
"hide_input_from_viewers": true,
|
441 |
+
"hide_output_from_viewers": true,
|
442 |
+
"node_id": "slHeNysZOX9uWS9PB7jFDb",
|
443 |
+
"type": "CODE"
|
444 |
+
},
|
445 |
+
"gather": {
|
446 |
+
"logged": 1706449721161
|
447 |
+
}
|
448 |
+
}
|
449 |
+
},
|
450 |
+
{
|
451 |
+
"cell_type": "markdown",
|
452 |
+
"source": [
|
453 |
+
"### Training"
|
454 |
+
],
|
455 |
+
"metadata": {}
|
456 |
+
},
|
457 |
+
{
|
458 |
+
"cell_type": "code",
|
459 |
+
"source": [
|
460 |
+
"class MultiLabelTrainer(Trainer):\n",
|
461 |
+
" def compute_loss(self, model, inputs, return_outputs=False):\n",
|
462 |
+
" labels = inputs.pop(\"labels\")\n",
|
463 |
+
" outputs = model(**inputs)\n",
|
464 |
+
" logits = outputs.logits\n",
|
465 |
+
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
|
466 |
+
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
|
467 |
+
" labels.float().view(-1, self.model.config.num_labels))\n",
|
468 |
+
" return (loss, outputs) if return_outputs else loss"
|
469 |
+
],
|
470 |
+
"outputs": [],
|
471 |
+
"execution_count": 17,
|
472 |
+
"metadata": {
|
473 |
+
"datalore": {
|
474 |
+
"hide_input_from_viewers": true,
|
475 |
+
"hide_output_from_viewers": true,
|
476 |
+
"node_id": "itXWkbDw9sqbkMuDP84QoT",
|
477 |
+
"type": "CODE"
|
478 |
+
},
|
479 |
+
"gather": {
|
480 |
+
"logged": 1706449743072
|
481 |
+
}
|
482 |
+
}
|
483 |
+
},
|
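The overridden loss treats the four outcomes as independent binary decisions rather than one softmax choice. A standalone sketch of what `BCEWithLogitsLoss` computes here, with invented logits and labels for a single report:

    import torch

    loss_fct = torch.nn.BCEWithLogitsLoss()
    logits = torch.tensor([[1.2, -0.7, 0.3, -2.1]])  # raw per-outcome scores
    labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])    # multi-hot ground truth
    print(loss_fct(logits, labels))                  # mean BCE over the four outcomes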
484 |
+
{
|
485 |
+
"cell_type": "code",
|
486 |
+
"source": [
|
487 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
|
488 |
+
],
|
489 |
+
"outputs": [
|
490 |
+
{
|
491 |
+
"output_type": "stream",
|
492 |
+
"name": "stderr",
|
493 |
+
"text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
494 |
+
}
|
495 |
+
],
|
496 |
+
"execution_count": 18,
|
497 |
+
"metadata": {
|
498 |
+
"datalore": {
|
499 |
+
"hide_input_from_viewers": true,
|
500 |
+
"hide_output_from_viewers": true,
|
501 |
+
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
|
502 |
+
"type": "CODE"
|
503 |
+
},
|
504 |
+
"gather": {
|
505 |
+
"logged": 1706449761205
|
506 |
+
}
|
507 |
+
}
|
508 |
+
},
|
509 |
+
{
|
510 |
+
"cell_type": "code",
|
511 |
+
"source": [
|
512 |
+
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
|
513 |
+
" y_pred = torch.from_numpy(y_pred)\n",
|
514 |
+
" y_true = torch.from_numpy(y_true)\n",
|
515 |
+
"\n",
|
516 |
+
" if sigmoid:\n",
|
517 |
+
" y_pred = y_pred.sigmoid()\n",
|
518 |
+
"\n",
|
519 |
+
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
|
520 |
+
],
|
521 |
+
"outputs": [],
|
522 |
+
"execution_count": 19,
|
523 |
+
"metadata": {
|
524 |
+
"datalore": {
|
525 |
+
"hide_input_from_viewers": true,
|
526 |
+
"hide_output_from_viewers": true,
|
527 |
+
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
|
528 |
+
"type": "CODE"
|
529 |
+
},
|
530 |
+
"gather": {
|
531 |
+
"logged": 1706449761541
|
532 |
+
}
|
533 |
+
}
|
534 |
+
},
|
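The metric is element-wise: logits are sigmoid-transformed, thresholded at 0.5, and the per-label hits are averaged, so a prediction can be partially correct. A worked toy example (numbers invented):

    import numpy as np

    y_pred = np.array([[2.0, -1.0, 0.2, -3.0]])  # logits; sigmoid > .5 iff logit > 0
    y_true = np.array([[1.0, 0.0, 0.0, 0.0]])
    # Thresholded predictions: [True, False, True, False] -> 3 of 4 labels match
    print(accuracy_threshold(y_pred, y_true))    # 0.75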
535 |
+
{
|
536 |
+
"cell_type": "code",
|
537 |
+
"source": [
|
538 |
+
"def compute_metrics(eval_pred):\n",
|
539 |
+
" predictions, labels = eval_pred\n",
|
540 |
+
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
|
541 |
+
],
|
542 |
+
"outputs": [],
|
543 |
+
"execution_count": 20,
|
544 |
+
"metadata": {
|
545 |
+
"datalore": {
|
546 |
+
"hide_input_from_viewers": true,
|
547 |
+
"hide_output_from_viewers": true,
|
548 |
+
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
|
549 |
+
"type": "CODE"
|
550 |
+
},
|
551 |
+
"gather": {
|
552 |
+
"logged": 1706449761720
|
553 |
+
}
|
554 |
+
}
|
555 |
+
},
|
556 |
+
{
|
557 |
+
"cell_type": "code",
|
558 |
+
"source": [
|
559 |
+
"args = TrainingArguments(\n",
|
560 |
+
" output_dir=\"vaers\",\n",
|
561 |
+
" evaluation_strategy=\"epoch\",\n",
|
562 |
+
" learning_rate=2e-5,\n",
|
563 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
564 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
565 |
+
" num_train_epochs=EPOCHS,\n",
|
566 |
+
" weight_decay=.01,\n",
|
567 |
+
" logging_steps=1,\n",
|
568 |
+
" run_name=f\"daedra-training\",\n",
|
569 |
+
" report_to=[\"wandb\"]\n",
|
570 |
+
")"
|
571 |
+
],
|
572 |
+
"outputs": [],
|
573 |
+
"execution_count": 21,
|
574 |
+
"metadata": {
|
575 |
+
"datalore": {
|
576 |
+
"hide_input_from_viewers": true,
|
577 |
+
"hide_output_from_viewers": true,
|
578 |
+
"node_id": "1iPZOTKPwSkTgX5dORqT89",
|
579 |
+
"type": "CODE"
|
580 |
+
},
|
581 |
+
"gather": {
|
582 |
+
"logged": 1706449761893
|
583 |
+
}
|
584 |
+
}
|
585 |
+
},
|
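Note that `per_device_train_batch_size` is per GPU, so on the two V100s reported by `nvidia-smi` above the effective training batch is twice `BATCH_SIZE`. A small sanity check, reusing the `BATCH_SIZE` constant defined earlier in the notebook:

    import torch

    n_gpus = max(torch.cuda.device_count(), 1)
    print(f"{n_gpus} GPU(s) -> effective train batch size {n_gpus * BATCH_SIZE}")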
586 |
+
{
|
587 |
+
"cell_type": "code",
|
588 |
+
"source": [
|
589 |
+
"multi_label_trainer = MultiLabelTrainer(\n",
|
590 |
+
" model, \n",
|
591 |
+
" args, \n",
|
592 |
+
" train_dataset=ds_enc[\"train\"], \n",
|
593 |
+
" eval_dataset=ds_enc[\"test\"], \n",
|
594 |
+
" compute_metrics=compute_metrics, \n",
|
595 |
+
" tokenizer=tokenizer\n",
|
596 |
+
")"
|
597 |
+
],
|
598 |
+
"outputs": [],
|
599 |
+
"execution_count": 22,
|
600 |
+
"metadata": {
|
601 |
+
"datalore": {
|
602 |
+
"hide_input_from_viewers": true,
|
603 |
+
"hide_output_from_viewers": true,
|
604 |
+
"node_id": "bnRkNvRYltLun6gCEgL7v0",
|
605 |
+
"type": "CODE"
|
606 |
+
},
|
607 |
+
"gather": {
|
608 |
+
"logged": 1706449769103
|
609 |
+
}
|
610 |
+
}
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"cell_type": "code",
|
614 |
+
"source": [
|
615 |
+
"if SUBSAMPLING != 1.0:\n",
|
616 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
617 |
+
"else:\n",
|
618 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
619 |
+
" \n",
|
620 |
+
"wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
|
621 |
+
"\n",
|
622 |
+
"multi_label_trainer.evaluate()\n",
|
623 |
+
"wandb.finish()"
|
624 |
+
],
|
625 |
+
"outputs": [
|
626 |
+
{
|
627 |
+
"output_type": "stream",
|
628 |
+
"name": "stderr",
|
629 |
+
"text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
630 |
+
},
|
631 |
+
{
|
632 |
+
"output_type": "display_data",
|
633 |
+
"data": {
|
634 |
+
"text/html": "Tracking run with wandb version 0.16.2",
|
635 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
636 |
+
},
|
637 |
+
"metadata": {}
|
638 |
+
},
|
639 |
+
{
|
640 |
+
"output_type": "display_data",
|
641 |
+
"data": {
|
642 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141956-9lniqjvz</code>",
|
643 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
644 |
+
},
|
645 |
+
"metadata": {}
|
646 |
+
},
|
647 |
+
{
|
648 |
+
"output_type": "display_data",
|
649 |
+
"data": {
|
650 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>",
|
651 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
652 |
+
},
|
653 |
+
"metadata": {}
|
654 |
+
},
|
655 |
+
{
|
656 |
+
"output_type": "display_data",
|
657 |
+
"data": {
|
658 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>",
|
659 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
660 |
+
},
|
661 |
+
"metadata": {}
|
662 |
+
},
|
663 |
+
{
|
664 |
+
"output_type": "display_data",
|
665 |
+
"data": {
|
666 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a>",
|
667 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
668 |
+
},
|
669 |
+
"metadata": {}
|
670 |
+
},
|
671 |
+
{
|
672 |
+
"output_type": "display_data",
|
673 |
+
"data": {
|
674 |
+
"text/html": "Finishing last run (ID:9lniqjvz) before initializing another...",
|
675 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
676 |
+
},
|
677 |
+
"metadata": {}
|
678 |
+
},
|
679 |
+
{
|
680 |
+
"output_type": "display_data",
|
681 |
+
"data": {
|
682 |
+
"text/html": " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/9lniqjvz</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)",
|
683 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
684 |
+
},
|
685 |
+
"metadata": {}
|
686 |
+
},
|
687 |
+
{
|
688 |
+
"output_type": "display_data",
|
689 |
+
"data": {
|
690 |
+
"text/html": "Find logs at: <code>./wandb/run-20240128_141956-9lniqjvz/logs</code>",
|
691 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
692 |
+
},
|
693 |
+
"metadata": {}
|
694 |
+
},
|
695 |
+
{
|
696 |
+
"output_type": "display_data",
|
697 |
+
"data": {
|
698 |
+
"text/html": "Successfully finished last run (ID:9lniqjvz). Initializing new run:<br/>",
|
699 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
700 |
+
},
|
701 |
+
"metadata": {}
|
702 |
+
},
|
703 |
+
{
|
704 |
+
"output_type": "display_data",
|
705 |
+
"data": {
|
706 |
+
"text/html": "Tracking run with wandb version 0.16.2",
|
707 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
708 |
+
},
|
709 |
+
"metadata": {}
|
710 |
+
},
|
711 |
+
{
|
712 |
+
"output_type": "display_data",
|
713 |
+
"data": {
|
714 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141958-5idmkcie</code>",
|
715 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
716 |
+
},
|
717 |
+
"metadata": {}
|
718 |
+
},
|
719 |
+
{
|
720 |
+
"output_type": "display_data",
|
721 |
+
"data": {
|
722 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>",
|
723 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
724 |
+
},
|
725 |
+
"metadata": {}
|
726 |
+
},
|
727 |
+
{
|
728 |
+
"output_type": "display_data",
|
729 |
+
"data": {
|
730 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>",
|
731 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
732 |
+
},
|
733 |
+
"metadata": {}
|
734 |
+
},
|
735 |
+
{
|
736 |
+
"output_type": "display_data",
|
737 |
+
"data": {
|
738 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a>",
|
739 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
740 |
+
},
|
741 |
+
"metadata": {}
|
742 |
+
},
|
743 |
+
{
|
744 |
+
"output_type": "display_data",
|
745 |
+
"data": {
|
746 |
+
"text/html": "\n <div>\n \n <progress value='1003' max='851' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [851/851 26:26]\n </div>\n ",
|
747 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
748 |
+
},
|
749 |
+
"metadata": {}
|
750 |
+
},
|
751 |
+
{
|
752 |
+
"output_type": "display_data",
|
753 |
+
"data": {
|
754 |
+
"text/html": "<style>\n table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n </style>\n<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>▁</td></tr><tr><td>eval/loss</td><td>▁</td></tr><tr><td>eval/runtime</td><td>▁</td></tr><tr><td>eval/samples_per_second</td><td>▁</td></tr><tr><td>eval/steps_per_second</td><td>▁</td></tr><tr><td>train/global_step</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>0.55198</td></tr><tr><td>eval/loss</td><td>0.68442</td></tr><tr><td>eval/runtime</td><td>105.0436</td></tr><tr><td>eval/samples_per_second</td><td>259.168</td></tr><tr><td>eval/steps_per_second</td><td>8.101</td></tr><tr><td>train/global_step</td><td>0</td></tr></table><br/></div></div>",
|
755 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
756 |
+
},
|
757 |
+
"metadata": {}
|
758 |
+
},
|
759 |
+
{
|
760 |
+
"output_type": "display_data",
|
761 |
+
"data": {
|
762 |
+
"text/html": " View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/5idmkcie</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)",
|
763 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
764 |
+
},
|
765 |
+
"metadata": {}
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"output_type": "display_data",
|
769 |
+
"data": {
|
770 |
+
"text/html": "Find logs at: <code>./wandb/run-20240128_141958-5idmkcie/logs</code>",
|
771 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
772 |
+
},
|
773 |
+
"metadata": {}
|
774 |
+
}
|
775 |
+
],
|
776 |
+
"execution_count": 23,
|
777 |
+
"metadata": {
|
778 |
+
"datalore": {
|
779 |
+
"hide_input_from_viewers": true,
|
780 |
+
"hide_output_from_viewers": true,
|
781 |
+
"node_id": "LO54PlDkWQdFrzV25FvduB",
|
782 |
+
"type": "CODE"
|
783 |
+
},
|
784 |
+
"gather": {
|
785 |
+
"logged": 1706449880674
|
786 |
+
}
|
787 |
+
}
|
788 |
+
},
|
789 |
+
{
|
790 |
+
"cell_type": "code",
|
791 |
+
"source": [
|
792 |
+
"if SUBSAMPLING != 1.0:\n",
|
793 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
794 |
+
"else:\n",
|
795 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
796 |
+
" \n",
|
797 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
|
798 |
+
"\n",
|
799 |
+
"multi_label_trainer.train()\n",
|
800 |
+
"wandb.finish()"
|
801 |
+
],
|
802 |
+
"outputs": [
|
803 |
+
{
|
804 |
+
"output_type": "display_data",
|
805 |
+
"data": {
|
806 |
+
"text/html": "Tracking run with wandb version 0.16.2",
|
807 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
808 |
+
},
|
809 |
+
"metadata": {}
|
810 |
+
},
|
811 |
+
{
|
812 |
+
"output_type": "display_data",
|
813 |
+
"data": {
|
814 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_142151-2mcc0ibc</code>",
|
815 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
816 |
+
},
|
817 |
+
"metadata": {}
|
818 |
+
},
|
819 |
+
{
|
820 |
+
"output_type": "display_data",
|
821 |
+
"data": {
|
822 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>",
|
823 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
824 |
+
},
|
825 |
+
"metadata": {}
|
826 |
+
},
|
827 |
+
{
|
828 |
+
"output_type": "display_data",
|
829 |
+
"data": {
|
830 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>",
|
831 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
832 |
+
},
|
833 |
+
"metadata": {}
|
834 |
+
},
|
835 |
+
{
|
836 |
+
"output_type": "display_data",
|
837 |
+
"data": {
|
838 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/2mcc0ibc</a>",
|
839 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
840 |
+
},
|
841 |
+
"metadata": {}
|
842 |
+
},
|
843 |
+
{
|
844 |
+
"output_type": "display_data",
|
845 |
+
"data": {
|
846 |
+
"text/html": "\n <div>\n \n <progress value='3972' max='11913' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 3972/11913 24:20 < 48:40, 2.72 it/s, Epoch 1/3]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>",
|
847 |
+
"text/plain": "<IPython.core.display.HTML object>"
|
848 |
+
},
|
849 |
+
"metadata": {}
|
850 |
+
},
|
851 |
+
{
|
852 |
+
"output_type": "stream",
|
853 |
+
"name": "stderr",
|
854 |
+
"text": "\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.6s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 22.7s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 14.0s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 15.2s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 14.0s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 12.4s\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 13.4s\n"
|
855 |
+
}
|
856 |
+
],
|
857 |
+
"execution_count": null,
|
858 |
+
"metadata": {
|
859 |
+
"datalore": {
|
860 |
+
"hide_input_from_viewers": true,
|
861 |
+
"hide_output_from_viewers": true,
|
862 |
+
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
|
863 |
+
"type": "CODE"
|
864 |
+
},
|
865 |
+
"gather": {
|
866 |
+
"logged": 1706449934637
|
867 |
+
}
|
868 |
+
}
|
869 |
+
},
|
870 |
+
{
|
871 |
+
"cell_type": "markdown",
|
872 |
+
"source": [
|
873 |
+
"### Evaluation"
|
874 |
+
],
|
875 |
+
"metadata": {}
|
876 |
+
},
|
877 |
+
{
|
878 |
+
"cell_type": "markdown",
|
879 |
+
"source": [
|
880 |
+
"We instantiate a classifier `pipeline` and push it to CUDA."
|
881 |
+
],
|
882 |
+
"metadata": {}
|
883 |
+
},
|
884 |
+
{
|
885 |
+
"cell_type": "code",
|
886 |
+
"source": [
|
887 |
+
"classifier = pipeline(\"text-classification\", \n",
|
888 |
+
" model, \n",
|
889 |
+
" tokenizer=tokenizer, \n",
|
890 |
+
" device=\"cuda:0\")"
|
891 |
+
],
|
892 |
+
"outputs": [],
|
893 |
+
"execution_count": null,
|
894 |
+
"metadata": {
|
895 |
+
"datalore": {
|
896 |
+
"hide_input_from_viewers": true,
|
897 |
+
"hide_output_from_viewers": true,
|
898 |
+
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
|
899 |
+
"type": "CODE"
|
900 |
+
},
|
901 |
+
"gather": {
|
902 |
+
"logged": 1706411459928
|
903 |
+
}
|
904 |
+
}
|
905 |
+
},
|
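One caveat: a `text-classification` pipeline returns only the single highest-scoring label by default, which is misleading for a multi-label model. A sketch of a multi-label-friendly invocation (`top_k` and `function_to_apply` are standard pipeline options, though this exact call is not in the committed cell):

    classifier = pipeline("text-classification",
                          model=model,
                          tokenizer=tokenizer,
                          device=0,
                          top_k=None,                   # return a score for every label
                          function_to_apply="sigmoid")  # independent per-label probabilities

    scores = classifier("Patient was hospitalised with a high fever.")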
906 |
+
{
|
907 |
+
"cell_type": "markdown",
|
908 |
+
"source": [
|
909 |
+
"We use the same tokenizer used for training to tokenize/encode the validation set."
|
910 |
+
],
|
911 |
+
"metadata": {}
|
912 |
+
},
|
913 |
+
{
|
914 |
+
"cell_type": "code",
|
915 |
+
"source": [
|
916 |
+
"test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
|
917 |
+
" max_length=None, \n",
|
918 |
+
" padding='max_length', \n",
|
919 |
+
" return_token_type_ids=True, \n",
|
920 |
+
" truncation=True)"
|
921 |
+
],
|
922 |
+
"outputs": [],
|
923 |
+
"execution_count": null,
|
924 |
+
"metadata": {
|
925 |
+
"datalore": {
|
926 |
+
"hide_input_from_viewers": true,
|
927 |
+
"hide_output_from_viewers": true,
|
928 |
+
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
|
929 |
+
"type": "CODE"
|
930 |
+
},
|
931 |
+
"gather": {
|
932 |
+
"logged": 1706411523285
|
933 |
+
}
|
934 |
+
}
|
935 |
+
},
|
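`batch_encode_plus` is a legacy entry point; in current `transformers` versions calling the tokenizer directly is equivalent. A minimal sketch with the same arguments:

    test_encodings = tokenizer(dataset["val"]["text"],
                               padding="max_length",       # pad to the model maximum
                               truncation=True,
                               return_token_type_ids=True)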
936 |
+
{
|
937 |
+
"cell_type": "markdown",
|
938 |
+
"source": [
|
939 |
+
"Once we've made the data loadable by putting it into a `DataLoader`, we "
|
940 |
+
],
|
941 |
+
"metadata": {}
|
942 |
+
},
|
943 |
+
{
|
944 |
+
"cell_type": "code",
|
945 |
+
"source": [
|
946 |
+
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
|
947 |
+
" torch.tensor(test_encodings['attention_mask']), \n",
|
948 |
+
" torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
|
949 |
+
" torch.tensor(test_encodings['token_type_ids']))\n",
|
950 |
+
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
|
951 |
+
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
|
952 |
+
" batch_size=BATCH_SIZE)"
|
953 |
+
],
|
954 |
+
"outputs": [],
|
955 |
+
"execution_count": null,
|
956 |
+
"metadata": {
|
957 |
+
"datalore": {
|
958 |
+
"hide_input_from_viewers": true,
|
959 |
+
"hide_output_from_viewers": true,
|
960 |
+
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
|
961 |
+
"type": "CODE"
|
962 |
+
},
|
963 |
+
"gather": {
|
964 |
+
"logged": 1706411543379
|
965 |
+
}
|
966 |
+
}
|
967 |
+
},
|
968 |
+
{
|
969 |
+
"cell_type": "code",
|
970 |
+
"source": [
|
971 |
+
"model.eval()\n",
|
972 |
+
"\n",
|
973 |
+
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
|
974 |
+
"\n",
|
975 |
+
"for i, batch in enumerate(test_dataloader):\n",
|
976 |
+
" batch = tuple(t.to(device) for t in batch)\n",
|
977 |
+
" \n",
|
978 |
+
" # Unpack the inputs from our dataloader\n",
|
979 |
+
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
|
980 |
+
" \n",
|
981 |
+
" with torch.no_grad():\n",
|
982 |
+
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
|
983 |
+
" b_logit_pred = outs[0]\n",
|
984 |
+
" pred_label = torch.sigmoid(b_logit_pred)\n",
|
985 |
+
"\n",
|
986 |
+
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
|
987 |
+
" pred_label = pred_label.to('cpu').numpy()\n",
|
988 |
+
" b_labels = b_labels.to('cpu').numpy()\n",
|
989 |
+
"\n",
|
990 |
+
" tokenized_texts.append(b_input_ids)\n",
|
991 |
+
" logit_preds.append(b_logit_pred)\n",
|
992 |
+
" true_labels.append(b_labels)\n",
|
993 |
+
" pred_labels.append(pred_label)\n",
|
994 |
+
"\n",
|
995 |
+
"# Flatten outputs\n",
|
996 |
+
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
|
997 |
+
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
|
998 |
+
"true_labels = [item for sublist in true_labels for item in sublist]\n",
|
999 |
+
"\n",
|
1000 |
+
"# Converting flattened binary values to boolean values\n",
|
1001 |
+
"true_bools = [tl == 1 for tl in true_labels]\n",
|
1002 |
+
"pred_bools = [pl > 0.50 for pl in pred_labels] "
|
1003 |
+
],
|
1004 |
+
"outputs": [],
|
1005 |
+
"execution_count": null,
|
1006 |
+
"metadata": {
|
1007 |
+
"datalore": {
|
1008 |
+
"hide_input_from_viewers": true,
|
1009 |
+
"hide_output_from_viewers": true,
|
1010 |
+
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
|
1011 |
+
"type": "CODE"
|
1012 |
+
},
|
1013 |
+
"gather": {
|
1014 |
+
"logged": 1706411587843
|
1015 |
+
}
|
1016 |
+
}
|
1017 |
+
},
|
1018 |
+
{
|
1019 |
+
"cell_type": "markdown",
|
1020 |
+
"source": [
|
1021 |
+
"We create a classification report:"
|
1022 |
+
],
|
1023 |
+
"metadata": {}
|
1024 |
+
},
|
1025 |
+
{
|
1026 |
+
"cell_type": "code",
|
1027 |
+
"source": [
|
1028 |
+
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
|
1029 |
+
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
|
1030 |
+
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
|
1031 |
+
"print(clf_report)"
|
1032 |
+
],
|
1033 |
+
"outputs": [],
|
1034 |
+
"execution_count": null,
|
1035 |
+
"metadata": {
|
1036 |
+
"datalore": {
|
1037 |
+
"hide_input_from_viewers": true,
|
1038 |
+
"hide_output_from_viewers": true,
|
1039 |
+
"node_id": "eBprrgF086mznPbPVBpOLS",
|
1040 |
+
"type": "CODE"
|
1041 |
+
},
|
1042 |
+
"gather": {
|
1043 |
+
"logged": 1706411588249
|
1044 |
+
}
|
1045 |
+
}
|
1046 |
+
},
|
1047 |
+
{
|
1048 |
+
"cell_type": "markdown",
|
1049 |
+
"source": [
|
1050 |
+
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
|
1051 |
+
],
|
1052 |
+
"metadata": {}
|
1053 |
+
},
|
1054 |
+
{
|
1055 |
+
"cell_type": "code",
|
1056 |
+
"source": [
|
1057 |
+
"# Creating a map of class names from class numbers\n",
|
1058 |
+
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
|
1059 |
+
],
|
1060 |
+
"outputs": [],
|
1061 |
+
"execution_count": null,
|
1062 |
+
"metadata": {
|
1063 |
+
"datalore": {
|
1064 |
+
"hide_input_from_viewers": true,
|
1065 |
+
"hide_output_from_viewers": true,
|
1066 |
+
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
|
1067 |
+
"type": "CODE"
|
1068 |
+
},
|
1069 |
+
"gather": {
|
1070 |
+
"logged": 1706411588638
|
1071 |
+
}
|
1072 |
+
}
|
1073 |
+
},
|
1074 |
+
{
|
1075 |
+
"cell_type": "code",
|
1076 |
+
"source": [
|
1077 |
+
"true_label_idxs, pred_label_idxs = [], []\n",
|
1078 |
+
"\n",
|
1079 |
+
"for vals in true_bools:\n",
|
1080 |
+
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
|
1081 |
+
"for vals in pred_bools:\n",
|
1082 |
+
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
|
1083 |
+
],
|
1084 |
+
"outputs": [],
|
1085 |
+
"execution_count": null,
|
1086 |
+
"metadata": {
|
1087 |
+
"datalore": {
|
1088 |
+
"hide_input_from_viewers": true,
|
1089 |
+
"hide_output_from_viewers": true,
|
1090 |
+
"node_id": "jH0S35dDteUch01sa6me6e",
|
1091 |
+
"type": "CODE"
|
1092 |
+
},
|
1093 |
+
"gather": {
|
1094 |
+
"logged": 1706411589004
|
1095 |
+
}
|
1096 |
+
}
|
1097 |
+
},
|
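`np.where(vals)[0]` converts a boolean label vector into the indices of the active labels; for example:

    import numpy as np

    print(np.where([False, True, True, False])[0].tolist())  # [1, 2]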
1098 |
+
{
|
1099 |
+
"cell_type": "code",
|
1100 |
+
"source": [
|
1101 |
+
"true_label_texts, pred_label_texts = [], []\n",
|
1102 |
+
"\n",
|
1103 |
+
"for vals in true_label_idxs:\n",
|
1104 |
+
" if vals:\n",
|
1105 |
+
" true_label_texts.append([idx2label[val] for val in vals])\n",
|
1106 |
+
" else:\n",
|
1107 |
+
" true_label_texts.append(vals)\n",
|
1108 |
+
"\n",
|
1109 |
+
"for vals in pred_label_idxs:\n",
|
1110 |
+
" if vals:\n",
|
1111 |
+
" pred_label_texts.append([idx2label[val] for val in vals])\n",
|
1112 |
+
" else:\n",
|
1113 |
+
" pred_label_texts.append(vals)"
|
1114 |
+
],
|
1115 |
+
"outputs": [],
|
1116 |
+
"execution_count": null,
|
1117 |
+
"metadata": {
|
1118 |
+
"datalore": {
|
1119 |
+
"hide_input_from_viewers": true,
|
1120 |
+
"hide_output_from_viewers": true,
|
1121 |
+
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
|
1122 |
+
"type": "CODE"
|
1123 |
+
},
|
1124 |
+
"gather": {
|
1125 |
+
"logged": 1706411589301
|
1126 |
+
}
|
1127 |
+
}
|
1128 |
+
},
|
1129 |
+
{
|
1130 |
+
"cell_type": "code",
|
1131 |
+
"source": [
|
1132 |
+
"symptom_texts = [tokenizer.decode(text,\n",
|
1133 |
+
" skip_special_tokens=True,\n",
|
1134 |
+
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
|
1135 |
+
],
|
1136 |
+
"outputs": [],
|
1137 |
+
"execution_count": null,
|
1138 |
+
"metadata": {
|
1139 |
+
"datalore": {
|
1140 |
+
"hide_input_from_viewers": true,
|
1141 |
+
"hide_output_from_viewers": true,
|
1142 |
+
"node_id": "SxUmVHfQISEeptg1SawOmB",
|
1143 |
+
"type": "CODE"
|
1144 |
+
},
|
1145 |
+
"gather": {
|
1146 |
+
"logged": 1706411591952
|
1147 |
+
}
|
1148 |
+
}
|
1149 |
+
},
|
1150 |
+
{
|
1151 |
+
"cell_type": "code",
|
1152 |
+
"source": [
|
1153 |
+
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
|
1154 |
+
" 'true_labels': true_label_texts, \n",
|
1155 |
+
" 'pred_labels':pred_label_texts})\n",
|
1156 |
+
"comparisons_df.to_csv('comparisons.csv')\n",
|
1157 |
+
"comparisons_df"
|
1158 |
+
],
|
1159 |
+
"outputs": [],
|
1160 |
+
"execution_count": null,
|
1161 |
+
"metadata": {
|
1162 |
+
"datalore": {
|
1163 |
+
"hide_input_from_viewers": true,
|
1164 |
+
"hide_output_from_viewers": true,
|
1165 |
+
"node_id": "BxFNigNGRLTOqraI55BPSH",
|
1166 |
+
"type": "CODE"
|
1167 |
+
},
|
1168 |
+
"gather": {
|
1169 |
+
"logged": 1706411592512
|
1170 |
+
}
|
1171 |
+
}
|
1172 |
+
}
|
1173 |
+
],
|
1174 |
+
"metadata": {
|
1175 |
+
"datalore": {
|
1176 |
+
"base_environment": "default",
|
1177 |
+
"computation_mode": "JUPYTER",
|
1178 |
+
"package_manager": "pip",
|
1179 |
+
"packages": [
|
1180 |
+
{
|
1181 |
+
"name": "datasets",
|
1182 |
+
"source": "PIP",
|
1183 |
+
"version": "2.16.1"
|
1184 |
+
},
|
1185 |
+
{
|
1186 |
+
"name": "torch",
|
1187 |
+
"source": "PIP",
|
1188 |
+
"version": "2.1.2"
|
1189 |
+
},
|
1190 |
+
{
|
1191 |
+
"name": "accelerate",
|
1192 |
+
"source": "PIP",
|
1193 |
+
"version": "0.26.1"
|
1194 |
+
}
|
1195 |
+
],
|
1196 |
+
"report_row_ids": [
|
1197 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
1198 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
1199 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
1200 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
1201 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
1202 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
1203 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
1204 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
1205 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
1206 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
1207 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
1208 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
1209 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
1210 |
+
],
|
1211 |
+
"version": 3
|
1212 |
+
},
|
1213 |
+
"kernelspec": {
|
1214 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow",
|
1215 |
+
"language": "python",
|
1216 |
+
"name": "python38-azureml-pt-tf"
|
1217 |
+
},
|
1218 |
+
"language_info": {
|
1219 |
+
"name": "python",
|
1220 |
+
"version": "3.8.5",
|
1221 |
+
"mimetype": "text/x-python",
|
1222 |
+
"codemirror_mode": {
|
1223 |
+
"name": "ipython",
|
1224 |
+
"version": 3
|
1225 |
+
},
|
1226 |
+
"pygments_lexer": "ipython3",
|
1227 |
+
"nbconvert_exporter": "python",
|
1228 |
+
"file_extension": ".py"
|
1229 |
+
},
|
1230 |
+
"microsoft": {
|
1231 |
+
"host": {
|
1232 |
+
"AzureML": {
|
1233 |
+
"notebookHasBeenCompleted": true
|
1234 |
+
}
|
1235 |
+
},
|
1236 |
+
"ms_spell_check": {
|
1237 |
+
"ms_spell_check_language": "en"
|
1238 |
+
}
|
1239 |
+
},
|
1240 |
+
"nteract": {
|
1241 |
+
"version": "nteract-front-end@1.0.0"
|
1242 |
+
}
|
1243 |
+
},
|
1244 |
+
"nbformat": 4,
|
1245 |
+
"nbformat_minor": 4
|
1246 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-20-56-58Z.ipynb
ADDED
@@ -0,0 +1,993 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
8 |
+
"\n",
|
9 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {
|
16 |
+
"nteract": {
|
17 |
+
"transient": {
|
18 |
+
"deleting": false
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"tags": []
|
22 |
+
},
|
23 |
+
"outputs": [],
|
24 |
+
"source": [
|
25 |
+
"# %pip install accelerate -U"
|
26 |
+
]
|
27 |
+
},
|
28 |
+
{
|
29 |
+
"cell_type": "code",
|
30 |
+
"execution_count": 2,
|
31 |
+
"metadata": {
|
32 |
+
"nteract": {
|
33 |
+
"transient": {
|
34 |
+
"deleting": false
|
35 |
+
}
|
36 |
+
}
|
37 |
+
},
|
38 |
+
"outputs": [
|
39 |
+
{
|
40 |
+
"name": "stdout",
|
41 |
+
"output_type": "stream",
|
42 |
+
"text": [
|
43 |
+
"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
|
44 |
+
"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
|
45 |
+
"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
|
46 |
+
"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
|
47 |
+
"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
|
48 |
+
"Requirement already satisfied: scikit-multilearn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.2.0)\n",
|
49 |
+
"Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
|
50 |
+
"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
|
51 |
+
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
|
52 |
+
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
|
53 |
+
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
|
54 |
+
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
|
55 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
|
56 |
+
"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
|
57 |
+
"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
|
58 |
+
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
|
59 |
+
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
|
60 |
+
"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
|
61 |
+
"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
|
62 |
+
"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
|
63 |
+
"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
|
64 |
+
"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
|
65 |
+
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
|
66 |
+
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
|
67 |
+
"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
|
68 |
+
"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
|
69 |
+
"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
|
70 |
+
"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
|
71 |
+
"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
|
72 |
+
"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
|
73 |
+
"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
|
74 |
+
"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
|
75 |
+
"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
|
76 |
+
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
|
77 |
+
"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
|
78 |
+
"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
|
79 |
+
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
|
80 |
+
"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
|
81 |
+
"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
|
82 |
+
"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
|
83 |
+
"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
|
84 |
+
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
|
85 |
+
"Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
|
86 |
+
"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
|
87 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
|
88 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
|
89 |
+
"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
|
90 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
|
91 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
|
92 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
|
93 |
+
"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
|
94 |
+
"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
|
95 |
+
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
|
96 |
+
"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
|
97 |
+
"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
|
98 |
+
"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
|
99 |
+
"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
|
100 |
+
"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
|
101 |
+
"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
|
102 |
+
"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
|
103 |
+
"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
|
104 |
+
"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
|
105 |
+
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
|
106 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
|
107 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
|
108 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
|
109 |
+
"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
|
110 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n",
|
111 |
+
"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
112 |
+
"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
113 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
|
114 |
+
"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
|
115 |
+
"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
|
116 |
+
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
|
117 |
+
"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
|
118 |
+
"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
|
119 |
+
"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
|
120 |
+
"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
|
121 |
+
"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
|
122 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
123 |
+
]
|
124 |
+
}
|
125 |
+
],
|
126 |
+
"source": [
|
127 |
+
"%pip install transformers datasets shap watermark wandb scikit-multilearn evaluate codecarbon"
|
128 |
+
]
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"cell_type": "code",
|
132 |
+
"execution_count": 3,
|
133 |
+
"metadata": {
|
134 |
+
"datalore": {
|
135 |
+
"hide_input_from_viewers": false,
|
136 |
+
"hide_output_from_viewers": false,
|
137 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
138 |
+
"report_properties": {
|
139 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
140 |
+
},
|
141 |
+
"type": "CODE"
|
142 |
+
},
|
143 |
+
"gather": {
|
144 |
+
"logged": 1706454586481
|
145 |
+
},
|
146 |
+
"tags": []
|
147 |
+
},
|
148 |
+
"outputs": [
|
149 |
+
{
|
150 |
+
"name": "stderr",
|
151 |
+
"output_type": "stream",
|
152 |
+
"text": [
|
153 |
+
"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
|
154 |
+
" from .autonotebook import tqdm as notebook_tqdm\n",
|
155 |
+
"2024-01-28 19:47:15.508449: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
|
156 |
+
"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
|
157 |
+
"2024-01-28 19:47:16.502791: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
|
158 |
+
"2024-01-28 19:47:16.502915: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
|
159 |
+
"2024-01-28 19:47:16.502928: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
160 |
+
]
|
161 |
+
}
|
162 |
+
],
|
163 |
+
"source": [
|
164 |
+
"import pandas as pd\n",
|
165 |
+
"import numpy as np\n",
|
166 |
+
"import torch\n",
|
167 |
+
"import os\n",
|
168 |
+
"from typing import List, Union\n",
|
169 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
170 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
|
171 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
172 |
+
"from pyarrow import Table\n",
|
173 |
+
"import shap\n",
|
174 |
+
"import wandb\n",
|
175 |
+
"import evaluate\n",
|
176 |
+
"from codecarbon import EmissionsTracker\n",
|
177 |
+
"\n",
|
178 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
179 |
+
"tracker = EmissionsTracker()\n",
|
180 |
+
"\n",
|
181 |
+
"%load_ext watermark"
|
182 |
+
]
|
183 |
+
},
|
184 |
+
{
|
185 |
+
"cell_type": "code",
|
186 |
+
"execution_count": 4,
|
187 |
+
"metadata": {
|
188 |
+
"collapsed": false,
|
189 |
+
"gather": {
|
190 |
+
"logged": 1706454586654
|
191 |
+
},
|
192 |
+
"jupyter": {
|
193 |
+
"outputs_hidden": false
|
194 |
+
}
|
195 |
+
},
|
196 |
+
"outputs": [],
|
197 |
+
"source": [
|
198 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
199 |
+
"\n",
|
200 |
+
"SEED: int = 42\n",
|
201 |
+
"\n",
|
202 |
+
"BATCH_SIZE: int = 16\n",
|
203 |
+
"EPOCHS: int = 3\n",
|
204 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
205 |
+
"\n",
|
206 |
+
"# WandB configuration\n",
|
207 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
208 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
209 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
210 |
+
]
|
211 |
+
},
|
212 |
+
{
|
213 |
+
"cell_type": "code",
|
214 |
+
"execution_count": 5,
|
215 |
+
"metadata": {
|
216 |
+
"collapsed": false,
|
217 |
+
"jupyter": {
|
218 |
+
"outputs_hidden": false
|
219 |
+
}
|
220 |
+
},
|
221 |
+
"outputs": [
|
222 |
+
{
|
223 |
+
"name": "stdout",
|
224 |
+
"output_type": "stream",
|
225 |
+
"text": [
|
226 |
+
"numpy : 1.23.5\n",
|
227 |
+
"re : 2.2.1\n",
|
228 |
+
"evaluate: 0.4.1\n",
|
229 |
+
"pandas : 2.0.2\n",
|
230 |
+
"wandb : 0.16.2\n",
|
231 |
+
"shap : 0.44.1\n",
|
232 |
+
"torch : 1.12.0\n",
|
233 |
+
"logging : 0.5.1.2\n",
|
234 |
+
"\n"
|
235 |
+
]
|
236 |
+
}
|
237 |
+
],
|
238 |
+
"source": [
|
239 |
+
"%watermark --iversion"
|
240 |
+
]
|
241 |
+
},
|
242 |
+
{
|
243 |
+
"cell_type": "code",
|
244 |
+
"execution_count": 6,
|
245 |
+
"metadata": {
|
246 |
+
"datalore": {
|
247 |
+
"hide_input_from_viewers": true,
|
248 |
+
"hide_output_from_viewers": true,
|
249 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
250 |
+
"type": "CODE"
|
251 |
+
}
|
252 |
+
},
|
253 |
+
"outputs": [
|
254 |
+
{
|
255 |
+
"name": "stdout",
|
256 |
+
"output_type": "stream",
|
257 |
+
"text": [
|
258 |
+
"Sun Jan 28 19:47:19 2024 \n",
|
259 |
+
"+---------------------------------------------------------------------------------------+\n",
|
260 |
+
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
|
261 |
+
"|-----------------------------------------+----------------------+----------------------+\n",
|
262 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
263 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
264 |
+
"| | | MIG M. |\n",
|
265 |
+
"|=========================================+======================+======================|\n",
|
266 |
+
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
|
267 |
+
"| N/A 29C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
268 |
+
"| | | N/A |\n",
|
269 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
270 |
+
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
|
271 |
+
"| N/A 29C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
272 |
+
"| | | N/A |\n",
|
273 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
274 |
+
" \n",
|
275 |
+
"+---------------------------------------------------------------------------------------+\n",
|
276 |
+
"| Processes: |\n",
|
277 |
+
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
278 |
+
"| ID ID Usage |\n",
|
279 |
+
"|=======================================================================================|\n",
|
280 |
+
"| No running processes found |\n",
|
281 |
+
"+---------------------------------------------------------------------------------------+\n"
|
282 |
+
]
|
283 |
+
}
|
284 |
+
],
|
285 |
+
"source": [
|
286 |
+
"!nvidia-smi"
|
287 |
+
]
|
288 |
+
},
|
289 |
+
{
|
290 |
+
"cell_type": "markdown",
|
291 |
+
"metadata": {
|
292 |
+
"datalore": {
|
293 |
+
"hide_input_from_viewers": false,
|
294 |
+
"hide_output_from_viewers": false,
|
295 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
296 |
+
"report_properties": {
|
297 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
298 |
+
},
|
299 |
+
"type": "MD"
|
300 |
+
}
|
301 |
+
},
|
302 |
+
"source": [
|
303 |
+
"## Loading the data set"
|
304 |
+
]
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"execution_count": 7,
|
309 |
+
"metadata": {
|
310 |
+
"collapsed": false,
|
311 |
+
"gather": {
|
312 |
+
"logged": 1706449040507
|
313 |
+
},
|
314 |
+
"jupyter": {
|
315 |
+
"outputs_hidden": false
|
316 |
+
}
|
317 |
+
},
|
318 |
+
"outputs": [],
|
319 |
+
"source": [
|
320 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
321 |
+
]
|
322 |
+
},
|
323 |
+
{
|
324 |
+
"cell_type": "code",
|
325 |
+
"execution_count": 8,
|
326 |
+
"metadata": {
|
327 |
+
"collapsed": false,
|
328 |
+
"gather": {
|
329 |
+
"logged": 1706449044205
|
330 |
+
},
|
331 |
+
"jupyter": {
|
332 |
+
"outputs_hidden": false,
|
333 |
+
"source_hidden": false
|
334 |
+
},
|
335 |
+
"nteract": {
|
336 |
+
"transient": {
|
337 |
+
"deleting": false
|
338 |
+
}
|
339 |
+
}
|
340 |
+
},
|
341 |
+
"outputs": [
|
342 |
+
{
|
343 |
+
"data": {
|
344 |
+
"text/plain": [
|
345 |
+
"DatasetDict({\n",
|
346 |
+
" train: Dataset({\n",
|
347 |
+
" features: ['id', 'text', 'label'],\n",
|
348 |
+
" num_rows: 1270444\n",
|
349 |
+
" })\n",
|
350 |
+
" test: Dataset({\n",
|
351 |
+
" features: ['id', 'text', 'label'],\n",
|
352 |
+
" num_rows: 272238\n",
|
353 |
+
" })\n",
|
354 |
+
" val: Dataset({\n",
|
355 |
+
" features: ['id', 'text', 'label'],\n",
|
356 |
+
" num_rows: 272238\n",
|
357 |
+
" })\n",
|
358 |
+
"})"
|
359 |
+
]
|
360 |
+
},
|
361 |
+
"execution_count": 8,
|
362 |
+
"metadata": {},
|
363 |
+
"output_type": "execute_result"
|
364 |
+
}
|
365 |
+
],
|
366 |
+
"source": [
|
367 |
+
"dataset"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": 9,
|
373 |
+
"metadata": {},
|
374 |
+
"outputs": [],
|
375 |
+
"source": [
|
376 |
+
"SUBSAMPLING = 0.1\n",
|
377 |
+
"\n",
|
378 |
+
"if SUBSAMPLING < 1:\n",
|
379 |
+
" _ = DatasetDict()\n",
|
380 |
+
" for each in dataset.keys():\n",
|
381 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
382 |
+
"\n",
|
383 |
+
" dataset = _"
|
384 |
+
]
|
385 |
+
},
|
386 |
+
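{
"cell_type": "markdown",
"metadata": {},
"source": [
"As a quick sanity check, a minimal sketch of how the label balance of the subsample might be verified, using `collections.Counter`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from collections import Counter\n",
"\n",
"# Distribution of outcome labels in the subsampled training split\n",
"label_names = dataset[\"train\"].features[\"label\"].names\n",
"counts = Counter(dataset[\"train\"][\"label\"])\n",
"for label_id, n in counts.most_common():\n",
"    print(f\"{label_names[label_id]}: {n} ({n / dataset['train'].num_rows:.1%})\")"
]
},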
{
|
387 |
+
"cell_type": "markdown",
|
388 |
+
"metadata": {},
|
389 |
+
"source": [
|
390 |
+
"## Tokenisation and encoding"
|
391 |
+
]
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"cell_type": "code",
|
395 |
+
"execution_count": 10,
|
396 |
+
"metadata": {},
|
397 |
+
"outputs": [],
|
398 |
+
"source": [
|
399 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
400 |
+
" return ds_enc"
|
401 |
+
]
|
402 |
+
},
|
403 |
+
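{
"cell_type": "markdown",
"metadata": {},
"source": [
"A minimal usage sketch for `encode_ds`, assuming the splits defined above:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Tokenise every split in one call; this adds input_ids and attention_mask columns\n",
"encoded = encode_ds(dataset)\n",
"print(encoded[\"train\"].column_names)"
]
},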
{
|
404 |
+
"cell_type": "markdown",
|
405 |
+
"metadata": {},
|
406 |
+
"source": [
|
407 |
+
"## Evaluation metrics"
|
408 |
+
]
|
409 |
+
},
|
410 |
+
{
|
411 |
+
"cell_type": "code",
|
412 |
+
"execution_count": 11,
|
413 |
+
"metadata": {},
|
414 |
+
"outputs": [],
|
415 |
+
"source": [
|
416 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
417 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
418 |
+
"f1 = evaluate.load(\"f1\")"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"execution_count": 12,
|
424 |
+
"metadata": {},
|
425 |
+
"outputs": [],
|
426 |
+
"source": [
|
427 |
+
"def compute_metrics(eval_pred):\n",
|
428 |
+
" predictions, labels = eval_pred\n",
|
429 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
430 |
+
" return {\n",
|
431 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
432 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
433 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
434 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
435 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
436 |
+
" 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
437 |
+
" }"
|
438 |
+
]
|
439 |
+
},
|
440 |
+
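{
"cell_type": "markdown",
"metadata": {},
"source": [
"To see what `compute_metrics` returns, a toy invocation with hand-made logits, shaped as `Trainer` passes them, i.e. `(n_examples, n_classes)` logits and `(n_examples,)` gold labels:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Two toy examples over three classes; argmax gives predictions [0, 1]\n",
"dummy_logits = np.array([[2.0, 0.1, 0.3], [0.2, 1.5, 0.1]])\n",
"dummy_labels = np.array([0, 1])\n",
"print(compute_metrics((dummy_logits, dummy_labels)))"
]
},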
{
|
441 |
+
"cell_type": "markdown",
|
442 |
+
"metadata": {},
|
443 |
+
"source": [
|
444 |
+
"## Training"
|
445 |
+
]
|
446 |
+
},
|
447 |
+
{
|
448 |
+
"cell_type": "markdown",
|
449 |
+
"metadata": {},
|
450 |
+
"source": [
|
451 |
+
"We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
|
452 |
+
]
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"cell_type": "code",
|
456 |
+
"execution_count": 13,
|
457 |
+
"metadata": {},
|
458 |
+
"outputs": [],
|
459 |
+
"source": [
|
460 |
+
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
461 |
+
]
|
462 |
+
},
|
463 |
+
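{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick inspection sketch showing the forward map and its inverse, which the model construction below uses as `id2label` and `label2id`:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# id -> name mapping derived from the ClassLabel feature, and its inverse\n",
"print(label_map)\n",
"print({v: k for k, v in label_map.items()})"
]
},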
{
|
464 |
+
"cell_type": "code",
|
465 |
+
"execution_count": 14,
|
466 |
+
"metadata": {},
|
467 |
+
"outputs": [
|
468 |
+
{
|
469 |
+
"name": "stderr",
|
470 |
+
"output_type": "stream",
|
471 |
+
"text": [
|
472 |
+
"Map: 100%|██████████| 127044/127044 [00:53<00:00, 2384.54 examples/s]\n",
|
473 |
+
"Map: 100%|██████████| 27223/27223 [00:11<00:00, 2396.71 examples/s]\n",
|
474 |
+
"Map: 100%|██████████| 27223/27223 [00:11<00:00, 2375.38 examples/s]\n",
|
475 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
|
476 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
477 |
+
]
|
478 |
+
}
|
479 |
+
],
|
480 |
+
"source": [
|
481 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
482 |
+
"\n",
|
483 |
+
"cols = dataset[\"train\"].column_names\n",
|
484 |
+
"cols.remove(\"label\")\n",
|
485 |
+
"ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n",
|
486 |
+
"\n",
|
487 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
488 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
489 |
+
" id2label=label_map, \n",
|
490 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
491 |
+
"\n",
|
492 |
+
"args = TrainingArguments(\n",
|
493 |
+
" output_dir=\"vaers\",\n",
|
494 |
+
" evaluation_strategy=\"epoch\",\n",
|
495 |
+
" save_strategy=\"epoch\",\n",
|
496 |
+
" learning_rate=2e-5,\n",
|
497 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
498 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
499 |
+
" num_train_epochs=EPOCHS,\n",
|
500 |
+
" weight_decay=.01,\n",
|
501 |
+
" logging_steps=1,\n",
|
502 |
+
" load_best_model_at_end=True,\n",
|
503 |
+
" run_name=f\"daedra-training\",\n",
|
504 |
+
" report_to=[\"wandb\"])\n",
|
505 |
+
"\n",
|
506 |
+
"trainer = Trainer(\n",
|
507 |
+
" model=model,\n",
|
508 |
+
" args=args,\n",
|
509 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
510 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
511 |
+
" tokenizer=tokenizer,\n",
|
512 |
+
" compute_metrics=compute_metrics)"
|
513 |
+
]
|
514 |
+
},
|
515 |
+
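{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once training has finished, the fine-tuned model can be exercised qualitatively through the `pipeline` helper imported above; a minimal sketch, with a synthetic example report:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Wrap the fine-tuned model and tokenizer in a text-classification pipeline\n",
"clf = pipeline(\"text-classification\", model=model, tokenizer=tokenizer,\n",
"               device=0 if torch.cuda.is_available() else -1)\n",
"print(clf(\"Patient developed a mild fever and localised swelling after vaccination.\"))"
]
},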
{
|
516 |
+
"cell_type": "code",
|
517 |
+
"execution_count": 15,
|
518 |
+
"metadata": {},
|
519 |
+
"outputs": [
|
520 |
+
{
|
521 |
+
"name": "stderr",
|
522 |
+
"output_type": "stream",
|
523 |
+
"text": [
|
524 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
|
525 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"data": {
|
530 |
+
"text/html": [
|
531 |
+
"Tracking run with wandb version 0.16.2"
|
532 |
+
],
|
533 |
+
"text/plain": [
|
534 |
+
"<IPython.core.display.HTML object>"
|
535 |
+
]
|
536 |
+
},
|
537 |
+
"metadata": {},
|
538 |
+
"output_type": "display_data"
|
539 |
+
},
|
540 |
+
{
|
541 |
+
"data": {
|
542 |
+
"text/html": [
|
543 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_194842-yvxddyg6</code>"
|
544 |
+
],
|
545 |
+
"text/plain": [
|
546 |
+
"<IPython.core.display.HTML object>"
|
547 |
+
]
|
548 |
+
},
|
549 |
+
"metadata": {},
|
550 |
+
"output_type": "display_data"
|
551 |
+
},
|
552 |
+
{
|
553 |
+
"data": {
|
554 |
+
"text/html": [
|
555 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
556 |
+
],
|
557 |
+
"text/plain": [
|
558 |
+
"<IPython.core.display.HTML object>"
|
559 |
+
]
|
560 |
+
},
|
561 |
+
"metadata": {},
|
562 |
+
"output_type": "display_data"
|
563 |
+
},
|
564 |
+
{
|
565 |
+
"data": {
|
566 |
+
"text/html": [
|
567 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
568 |
+
],
|
569 |
+
"text/plain": [
|
570 |
+
"<IPython.core.display.HTML object>"
|
571 |
+
]
|
572 |
+
},
|
573 |
+
"metadata": {},
|
574 |
+
"output_type": "display_data"
|
575 |
+
},
|
576 |
+
{
|
577 |
+
"data": {
|
578 |
+
"text/html": [
|
579 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6</a>"
|
580 |
+
],
|
581 |
+
"text/plain": [
|
582 |
+
"<IPython.core.display.HTML object>"
|
583 |
+
]
|
584 |
+
},
|
585 |
+
"metadata": {},
|
586 |
+
"output_type": "display_data"
|
587 |
+
},
|
588 |
+
{
|
589 |
+
"data": {
|
590 |
+
"text/html": [
|
591 |
+
"Finishing last run (ID:yvxddyg6) before initializing another..."
|
592 |
+
],
|
593 |
+
"text/plain": [
|
594 |
+
"<IPython.core.display.HTML object>"
|
595 |
+
]
|
596 |
+
},
|
597 |
+
"metadata": {},
|
598 |
+
"output_type": "display_data"
|
599 |
+
},
|
600 |
+
{
|
601 |
+
"data": {
|
602 |
+
"text/html": [
|
603 |
+
" View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/yvxddyg6</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
604 |
+
],
|
605 |
+
"text/plain": [
|
606 |
+
"<IPython.core.display.HTML object>"
|
607 |
+
]
|
608 |
+
},
|
609 |
+
"metadata": {},
|
610 |
+
"output_type": "display_data"
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"data": {
|
614 |
+
"text/html": [
|
615 |
+
"Find logs at: <code>./wandb/run-20240128_194842-yvxddyg6/logs</code>"
|
616 |
+
],
|
617 |
+
"text/plain": [
|
618 |
+
"<IPython.core.display.HTML object>"
|
619 |
+
]
|
620 |
+
},
|
621 |
+
"metadata": {},
|
622 |
+
"output_type": "display_data"
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"data": {
|
626 |
+
"text/html": [
|
627 |
+
"Successfully finished last run (ID:yvxddyg6). Initializing new run:<br/>"
|
628 |
+
],
|
629 |
+
"text/plain": [
|
630 |
+
"<IPython.core.display.HTML object>"
|
631 |
+
]
|
632 |
+
},
|
633 |
+
"metadata": {},
|
634 |
+
"output_type": "display_data"
|
635 |
+
},
|
636 |
+
{
|
637 |
+
"data": {
|
638 |
+
"text/html": [
|
639 |
+
"Tracking run with wandb version 0.16.2"
|
640 |
+
],
|
641 |
+
"text/plain": [
|
642 |
+
"<IPython.core.display.HTML object>"
|
643 |
+
]
|
644 |
+
},
|
645 |
+
"metadata": {},
|
646 |
+
"output_type": "display_data"
|
647 |
+
},
|
648 |
+
{
|
649 |
+
"data": {
|
650 |
+
"text/html": [
|
651 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_194845-9g8te2gf</code>"
|
652 |
+
],
|
653 |
+
"text/plain": [
|
654 |
+
"<IPython.core.display.HTML object>"
|
655 |
+
]
|
656 |
+
},
|
657 |
+
"metadata": {},
|
658 |
+
"output_type": "display_data"
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"data": {
|
662 |
+
"text/html": [
|
663 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
664 |
+
],
|
665 |
+
"text/plain": [
|
666 |
+
"<IPython.core.display.HTML object>"
|
667 |
+
]
|
668 |
+
},
|
669 |
+
"metadata": {},
|
670 |
+
"output_type": "display_data"
|
671 |
+
},
|
672 |
+
{
|
673 |
+
"data": {
|
674 |
+
"text/html": [
|
675 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
676 |
+
],
|
677 |
+
"text/plain": [
|
678 |
+
"<IPython.core.display.HTML object>"
|
679 |
+
]
|
680 |
+
},
|
681 |
+
"metadata": {},
|
682 |
+
"output_type": "display_data"
|
683 |
+
},
|
684 |
+
{
|
685 |
+
"data": {
|
686 |
+
"text/html": [
|
687 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf</a>"
|
688 |
+
],
|
689 |
+
"text/plain": [
|
690 |
+
"<IPython.core.display.HTML object>"
|
691 |
+
]
|
692 |
+
},
|
693 |
+
"metadata": {},
|
694 |
+
"output_type": "display_data"
|
695 |
+
},
|
696 |
+
{
|
697 |
+
"data": {
|
698 |
+
"text/html": [
|
699 |
+
"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/9g8te2gf?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
|
700 |
+
],
|
701 |
+
"text/plain": [
|
702 |
+
"<wandb.sdk.wandb_run.Run at 0x7fb8b483bf40>"
|
703 |
+
]
|
704 |
+
},
|
705 |
+
"execution_count": 15,
|
706 |
+
"metadata": {},
|
707 |
+
"output_type": "execute_result"
|
708 |
+
}
|
709 |
+
],
|
710 |
+
"source": [
|
711 |
+
"if SUBSAMPLING != 1.0:\n",
|
712 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
713 |
+
"else:\n",
|
714 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
715 |
+
"\n",
|
716 |
+
"wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
717 |
+
"wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
718 |
+
" \n",
|
719 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
|
720 |
+
]
|
721 |
+
},
|
722 |
+
{
|
723 |
+
"cell_type": "code",
|
724 |
+
"execution_count": 16,
|
725 |
+
"metadata": {},
|
726 |
+
"outputs": [
|
727 |
+
{
|
728 |
+
"name": "stderr",
|
729 |
+
"output_type": "stream",
|
730 |
+
"text": [
|
731 |
+
"Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
|
732 |
+
]
|
733 |
+
},
|
734 |
+
{
|
735 |
+
"data": {
|
736 |
+
"text/html": [
|
737 |
+
"\n",
|
738 |
+
" <div>\n",
|
739 |
+
" \n",
|
740 |
+
" <progress value='7943' max='11913' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
741 |
+
" [ 7943/11913 43:43 < 21:51, 3.03 it/s, Epoch 2/3]\n",
|
742 |
+
" </div>\n",
|
743 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
744 |
+
" <thead>\n",
|
745 |
+
" <tr style=\"text-align: left;\">\n",
|
746 |
+
" <th>Epoch</th>\n",
|
747 |
+
" <th>Training Loss</th>\n",
|
748 |
+
" <th>Validation Loss</th>\n",
|
749 |
+
" <th>Accuracy</th>\n",
|
750 |
+
" <th>Precision Macroaverage</th>\n",
|
751 |
+
" <th>Precision Microaverage</th>\n",
|
752 |
+
" <th>Recall Macroaverage</th>\n",
|
753 |
+
" <th>Recall Microaverage</th>\n",
|
754 |
+
" <th>F1 Microaverage</th>\n",
|
755 |
+
" </tr>\n",
|
756 |
+
" </thead>\n",
|
757 |
+
" <tbody>\n",
|
758 |
+
" <tr>\n",
|
759 |
+
" <td>1</td>\n",
|
760 |
+
" <td>0.251300</td>\n",
|
761 |
+
" <td>0.362917</td>\n",
|
762 |
+
" <td>0.865775</td>\n",
|
763 |
+
" <td>0.701081</td>\n",
|
764 |
+
" <td>0.865775</td>\n",
|
765 |
+
" <td>0.556570</td>\n",
|
766 |
+
" <td>0.865775</td>\n",
|
767 |
+
" <td>0.865775</td>\n",
|
768 |
+
" </tr>\n",
|
769 |
+
" <tr>\n",
|
770 |
+
" <td>2</td>\n",
|
771 |
+
" <td>0.036000</td>\n",
|
772 |
+
" <td>0.352118</td>\n",
|
773 |
+
" <td>0.870551</td>\n",
|
774 |
+
" <td>0.728051</td>\n",
|
775 |
+
" <td>0.870551</td>\n",
|
776 |
+
" <td>0.609787</td>\n",
|
777 |
+
" <td>0.870551</td>\n",
|
778 |
+
" <td>0.870551</td>\n",
|
779 |
+
" </tr>\n",
|
780 |
+
" </tbody>\n",
|
781 |
+
"</table><p>"
|
782 |
+
],
|
783 |
+
"text/plain": [
|
784 |
+
"<IPython.core.display.HTML object>"
|
785 |
+
]
|
786 |
+
},
|
787 |
+
"metadata": {},
|
788 |
+
"output_type": "display_data"
|
789 |
+
},
|
790 |
+
{
|
791 |
+
"name": "stderr",
|
792 |
+
"output_type": "stream",
|
793 |
+
"text": [
|
794 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3971)... Done. 18.2s\n",
|
795 |
+
"Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n",
|
796 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-7942)... "
|
797 |
+
]
|
798 |
+
}
|
799 |
+
],
|
800 |
+
"source": [
|
801 |
+
"tracker.start()\n",
|
802 |
+
"trainer.train()\n",
|
803 |
+
"tracker.stop()\n"
|
804 |
+
]
|
805 |
+
},
|
806 |
+
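{
"cell_type": "markdown",
"metadata": {},
"source": [
"CodeCarbon writes its measurements to `emissions.csv` in the working directory by default; a small sketch for inspecting them, assuming the default output path and column layout:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Each tracked session appends one row to emissions.csv\n",
"emissions = pd.read_csv(\"emissions.csv\")\n",
"print(emissions[[\"timestamp\", \"duration\", \"emissions\", \"energy_consumed\"]].tail())"
]
},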
{
|
807 |
+
"cell_type": "code",
|
808 |
+
"execution_count": null,
|
809 |
+
"metadata": {},
|
810 |
+
"outputs": [
|
811 |
+
{
|
812 |
+
"data": {
|
813 |
+
"text/html": [
|
814 |
+
"<style>\n",
|
815 |
+
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
|
816 |
+
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
|
817 |
+
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
|
818 |
+
" </style>\n",
|
819 |
+
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>▁▇█</td></tr><tr><td>eval/f1_microaverage</td><td>▁▇█</td></tr><tr><td>eval/loss</td><td>█▃▁</td></tr><tr><td>eval/precision_macroaverage</td><td>▁▇█</td></tr><tr><td>eval/precision_microaverage</td><td>▁▇█</td></tr><tr><td>eval/recall_macroaverage</td><td>▁▇█</td></tr><tr><td>eval/recall_microaverage</td><td>▁▇█</td></tr><tr><td>eval/runtime</td><td>▁▃█</td></tr><tr><td>eval/samples_per_second</td><td>█▆▁</td></tr><tr><td>eval/steps_per_second</td><td>█▆▁</td></tr><tr><td>train/epoch</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▄▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>train/global_step</td><td>▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███</td></tr><tr><td>train/learning_rate</td><td>████▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁</td></tr><tr><td>train/loss</td><td>█▅▆▆▅▄▄▃▆▅▃▃▅▄▆▄▄▄▂▄▄▅▄▃▄▄▁▄▂▂▃▃▃▂▂▃▂▃▃▂</td></tr><tr><td>train/total_flos</td><td>▁</td></tr><tr><td>train/train_loss</td><td>▁</td></tr><tr><td>train/train_runtime</td><td>▁</td></tr><tr><td>train/train_samples_per_second</td><td>▁</td></tr><tr><td>train/train_steps_per_second</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy</td><td>0.84019</td></tr><tr><td>eval/f1_microaverage</td><td>0.84019</td></tr><tr><td>eval/loss</td><td>0.44011</td></tr><tr><td>eval/precision_macroaverage</td><td>0.415</td></tr><tr><td>eval/precision_microaverage</td><td>0.84019</td></tr><tr><td>eval/recall_macroaverage</td><td>0.40704</td></tr><tr><td>eval/recall_microaverage</td><td>0.84019</td></tr><tr><td>eval/runtime</td><td>10.0118</td></tr><tr><td>eval/samples_per_second</td><td>271.878</td></tr><tr><td>eval/steps_per_second</td><td>8.59</td></tr><tr><td>train/epoch</td><td>3.0</td></tr><tr><td>train/global_step</td><td>1191</td></tr><tr><td>train/learning_rate</td><td>0.0</td></tr><tr><td>train/loss</td><td>0.1782</td></tr><tr><td>train/total_flos</td><td>4885522962505728.0</td></tr><tr><td>train/train_loss</td><td>0.4724</td></tr><tr><td>train/train_runtime</td><td>483.5027</td></tr><tr><td>train/train_samples_per_second</td><td>78.825</td></tr><tr><td>train/train_steps_per_second</td><td>2.463</td></tr></table><br/></div></div>"
|
820 |
+
],
|
821 |
+
"text/plain": [
|
822 |
+
"<IPython.core.display.HTML object>"
|
823 |
+
]
|
824 |
+
},
|
825 |
+
"metadata": {},
|
826 |
+
"output_type": "display_data"
|
827 |
+
},
|
828 |
+
{
|
829 |
+
"data": {
|
830 |
+
"text/html": [
|
831 |
+
" View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/3xvt3c2y' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/3xvt3c2y</a><br/>Synced 5 W&B file(s), 0 media file(s), 40 artifact file(s) and 0 other file(s)"
|
832 |
+
],
|
833 |
+
"text/plain": [
|
834 |
+
"<IPython.core.display.HTML object>"
|
835 |
+
]
|
836 |
+
},
|
837 |
+
"metadata": {},
|
838 |
+
"output_type": "display_data"
|
839 |
+
},
|
840 |
+
{
|
841 |
+
"data": {
|
842 |
+
"text/html": [
|
843 |
+
"Find logs at: <code>./wandb/run-20240128_192000-3xvt3c2y/logs</code>"
|
844 |
+
],
|
845 |
+
"text/plain": [
|
846 |
+
"<IPython.core.display.HTML object>"
|
847 |
+
]
|
848 |
+
},
|
849 |
+
"metadata": {},
|
850 |
+
"output_type": "display_data"
|
851 |
+
}
|
852 |
+
],
|
853 |
+
"source": [
|
854 |
+
"wandb.finish()"
|
855 |
+
]
|
856 |
+
},
|
857 |
+
{
|
858 |
+
"cell_type": "code",
|
859 |
+
"execution_count": null,
|
860 |
+
"metadata": {},
|
861 |
+
"outputs": [
|
862 |
+
{
|
863 |
+
"data": {
|
864 |
+
"text/plain": [
|
865 |
+
"CommitInfo(commit_url='https://huggingface.co/chrisvoncsefalvay/daedra/commit/c482ca6c8520142a3e67df4be25a408e6b557053', commit_message='DAEDRA model trained on 1.0% of the full sample of the VAERS dataset (training set size: 12,704)', commit_description='', oid='c482ca6c8520142a3e67df4be25a408e6b557053', pr_url=None, pr_revision=None, pr_num=None)"
|
866 |
+
]
|
867 |
+
},
|
868 |
+
"execution_count": 31,
|
869 |
+
"metadata": {},
|
870 |
+
"output_type": "execute_result"
|
871 |
+
}
|
872 |
+
],
|
873 |
+
"source": [
|
874 |
+
"variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
|
875 |
+
"tokenizer._tokenizer.save(\"tokenizer.json\")\n",
|
876 |
+
"tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
|
877 |
+
"sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
|
878 |
+
"\n",
|
879 |
+
"model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
|
880 |
+
" variant=variant,\n",
|
881 |
+
" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
|
882 |
+
]
|
883 |
+
},
|
884 |
+
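{
"cell_type": "markdown",
"metadata": {},
"source": [
"Once pushed, the checkpoint can be pulled straight back from the Hub for inference; a minimal sketch, with a synthetic example report:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Load the published model and tokenizer from the Hugging Face Hub\n",
"daedra = pipeline(\"text-classification\", model=\"chrisvoncsefalvay/daedra\")\n",
"print(daedra(\"Patient reported headache and dizziness the day after the second dose.\"))"
]
},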
{
|
885 |
+
"cell_type": "code",
|
886 |
+
"execution_count": null,
|
887 |
+
"metadata": {},
|
888 |
+
"outputs": [],
|
889 |
+
"source": [
|
890 |
+
"from collections import Counter\n",
|
891 |
+
"\n",
|
892 |
+
"def get_most_frequent_unknown_tokens(tokenizer, dataset):\n",
|
893 |
+
" unknown_tokens = []\n",
|
894 |
+
" \n",
|
895 |
+
" # Tokenize each text in the dataset\n",
|
896 |
+
" for example in dataset:\n",
|
897 |
+
" tokens = tokenizer.tokenize(example['text'])\n",
|
898 |
+
" \n",
|
899 |
+
" # Check if each token is the 'unknown' special token\n",
|
900 |
+
" for token in tokens:\n",
|
901 |
+
" if token == tokenizer.unk_token:\n",
|
902 |
+
" unknown_tokens.append(token)\n",
|
903 |
+
" \n",
|
904 |
+
" # Count the frequency of each unique unknown token\n",
|
905 |
+
" token_counts = Counter(unknown_tokens)\n",
|
906 |
+
" \n",
|
907 |
+
" # Sort the tokens based on their frequency in descending order\n",
|
908 |
+
" most_frequent_tokens = token_counts.most_common()\n",
|
909 |
+
" \n",
|
910 |
+
" return most_frequent_tokens\n",
|
911 |
+
"\n",
|
912 |
+
"# Example usage\n",
|
913 |
+
"tokenizer = YourTokenizer() # Replace with your tokenizer\n",
|
914 |
+
"dataset = YourDataset() # Replace with your dataset\n",
|
915 |
+
"\n",
|
916 |
+
"most_frequent_unknown_tokens = get_most_frequent_unknown_tokens(tokenizer, dataset)\n",
|
917 |
+
"print(most_frequent_unknown_tokens)\n"
|
918 |
+
]
|
919 |
+
}
|
920 |
+
],
|
921 |
+
"metadata": {
|
922 |
+
"datalore": {
|
923 |
+
"base_environment": "default",
|
924 |
+
"computation_mode": "JUPYTER",
|
925 |
+
"package_manager": "pip",
|
926 |
+
"packages": [
|
927 |
+
{
|
928 |
+
"name": "datasets",
|
929 |
+
"source": "PIP",
|
930 |
+
"version": "2.16.1"
|
931 |
+
},
|
932 |
+
{
|
933 |
+
"name": "torch",
|
934 |
+
"source": "PIP",
|
935 |
+
"version": "2.1.2"
|
936 |
+
},
|
937 |
+
{
|
938 |
+
"name": "accelerate",
|
939 |
+
"source": "PIP",
|
940 |
+
"version": "0.26.1"
|
941 |
+
}
|
942 |
+
],
|
943 |
+
"report_row_ids": [
|
944 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
945 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
946 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
947 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
948 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
949 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
950 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
951 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
952 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
953 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
954 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
955 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
956 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
957 |
+
],
|
958 |
+
"version": 3
|
959 |
+
},
|
960 |
+
"kernelspec": {
|
961 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow",
|
962 |
+
"language": "python",
|
963 |
+
"name": "python38-azureml-pt-tf"
|
964 |
+
},
|
965 |
+
"language_info": {
|
966 |
+
"codemirror_mode": {
|
967 |
+
"name": "ipython",
|
968 |
+
"version": 3
|
969 |
+
},
|
970 |
+
"file_extension": ".py",
|
971 |
+
"mimetype": "text/x-python",
|
972 |
+
"name": "python",
|
973 |
+
"nbconvert_exporter": "python",
|
974 |
+
"pygments_lexer": "ipython3",
|
975 |
+
"version": "3.8.5"
|
976 |
+
},
|
977 |
+
"microsoft": {
|
978 |
+
"host": {
|
979 |
+
"AzureML": {
|
980 |
+
"notebookHasBeenCompleted": true
|
981 |
+
}
|
982 |
+
},
|
983 |
+
"ms_spell_check": {
|
984 |
+
"ms_spell_check_language": "en"
|
985 |
+
}
|
986 |
+
},
|
987 |
+
"nteract": {
|
988 |
+
"version": "nteract-front-end@1.0.0"
|
989 |
+
}
|
990 |
+
},
|
991 |
+
"nbformat": 4,
|
992 |
+
"nbformat_minor": 4
|
993 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-23-54-39Z.ipynb
ADDED
@@ -0,0 +1,692 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {}
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"source": [
|
15 |
+
"%pip install accelerate -U"
|
16 |
+
],
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"output_type": "stream",
|
20 |
+
"name": "stdout",
|
21 |
+
"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"execution_count": 1,
|
25 |
+
"metadata": {
|
26 |
+
"nteract": {
|
27 |
+
"transient": {
|
28 |
+
"deleting": false
|
29 |
+
}
|
30 |
+
},
|
31 |
+
"tags": [],
|
32 |
+
"gather": {
|
33 |
+
"logged": 1706475754655
|
34 |
+
}
|
35 |
+
}
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"source": [
|
40 |
+
"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
|
41 |
+
],
|
42 |
+
"outputs": [
|
43 |
+
{
|
44 |
+
"output_type": "stream",
|
45 |
+
"name": "stdout",
|
46 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from 
shap) (0.58.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: pytz>=2020.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nNote: you may need to restart the kernel to use updated packages.\n"
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"execution_count": 2,
|
50 |
+
"metadata": {
|
51 |
+
"nteract": {
|
52 |
+
"transient": {
|
53 |
+
"deleting": false
|
54 |
+
}
|
55 |
+
}
|
56 |
+
}
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"import pandas as pd\n",
|
62 |
+
"import numpy as np\n",
|
63 |
+
"import torch\n",
|
64 |
+
"import os\n",
|
65 |
+
"from typing import List, Union\n",
|
66 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
|
67 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
68 |
+
"import shap\n",
|
69 |
+
"import wandb\n",
|
70 |
+
"import evaluate\n",
|
71 |
+
"from codecarbon import EmissionsTracker\n",
|
72 |
+
"\n",
|
73 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
74 |
+
"tracker = EmissionsTracker()\n",
|
75 |
+
"\n",
|
76 |
+
"%load_ext watermark"
|
77 |
+
],
|
78 |
+
"outputs": [
|
79 |
+
{
|
80 |
+
"output_type": "stream",
|
81 |
+
"name": "stderr",
|
82 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 21:14:33.562898: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 21:14:34.581816: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 21:14:34.581943: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 21:14:34.581956: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n[codecarbon INFO @ 21:14:37] [setup] RAM Tracking...\n[codecarbon INFO @ 21:14:37] [setup] GPU Tracking...\n[codecarbon INFO @ 21:14:37] Tracking Nvidia GPU via pynvml\n[codecarbon INFO @ 21:14:37] [setup] CPU Tracking...\n[codecarbon WARNING @ 21:14:37] No CPU tracking mode found. Falling back on CPU constant mode.\n[codecarbon WARNING @ 21:14:38] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n[codecarbon INFO @ 21:14:38] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 21:14:38] >>> Tracker's metadata:\n[codecarbon INFO @ 21:14:38] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n[codecarbon INFO @ 21:14:38] Python version: 3.8.5\n[codecarbon INFO @ 21:14:38] CodeCarbon version: 2.3.3\n[codecarbon INFO @ 21:14:38] Available RAM : 440.883 GB\n[codecarbon INFO @ 21:14:38] CPU count: 24\n[codecarbon INFO @ 21:14:38] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 21:14:38] GPU count: 4\n[codecarbon INFO @ 21:14:38] GPU model: 4 x Tesla V100-PCIE-16GB\n[codecarbon WARNING @ 21:14:38] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
|
83 |
+
}
|
84 |
+
],
|
85 |
+
"execution_count": 3,
|
86 |
+
"metadata": {
|
87 |
+
"datalore": {
|
88 |
+
"hide_input_from_viewers": false,
|
89 |
+
"hide_output_from_viewers": false,
|
90 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
91 |
+
"report_properties": {
|
92 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
93 |
+
},
|
94 |
+
"type": "CODE"
|
95 |
+
},
|
96 |
+
"gather": {
|
97 |
+
"logged": 1706476478659
|
98 |
+
},
|
99 |
+
"tags": []
|
100 |
+
}
|
101 |
+
},
|
102 |
+
{
|
103 |
+
"cell_type": "code",
|
104 |
+
"source": [
|
105 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
106 |
+
"\n",
|
107 |
+
"SEED: int = 42\n",
|
108 |
+
"\n",
|
109 |
+
"BATCH_SIZE: int = 32\n",
|
110 |
+
"EPOCHS: int = 3\n",
|
111 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
112 |
+
"\n",
|
113 |
+
"# WandB configuration\n",
|
114 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
115 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
116 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
117 |
+
],
|
118 |
+
"outputs": [],
|
119 |
+
"execution_count": 4,
|
120 |
+
"metadata": {
|
121 |
+
"collapsed": false,
|
122 |
+
"gather": {
|
123 |
+
"logged": 1706476478863
|
124 |
+
},
|
125 |
+
"jupyter": {
|
126 |
+
"outputs_hidden": false
|
127 |
+
}
|
128 |
+
}
|
129 |
+
},
|
130 |
+
{
|
131 |
+
"cell_type": "code",
|
132 |
+
"source": [
|
133 |
+
"%watermark --iversion"
|
134 |
+
],
|
135 |
+
"outputs": [
|
136 |
+
{
|
137 |
+
"output_type": "stream",
|
138 |
+
"name": "stdout",
|
139 |
+
"text": "shap : 0.44.1\nre : 2.2.1\ntorch : 1.12.0\nevaluate: 0.4.1\nwandb : 0.16.2\nlogging : 0.5.1.2\npandas : 2.0.2\nnumpy : 1.23.5\n\n"
|
140 |
+
}
|
141 |
+
],
|
142 |
+
"execution_count": 5,
|
143 |
+
"metadata": {
|
144 |
+
"collapsed": false,
|
145 |
+
"jupyter": {
|
146 |
+
"outputs_hidden": false
|
147 |
+
}
|
148 |
+
}
|
149 |
+
},
|
150 |
+
{
|
151 |
+
"cell_type": "code",
|
152 |
+
"source": [
|
153 |
+
"!nvidia-smi"
|
154 |
+
],
|
155 |
+
"outputs": [
|
156 |
+
{
|
157 |
+
"output_type": "stream",
|
158 |
+
"name": "stdout",
|
159 |
+
"text": "Sun Jan 28 21:14:38 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
160 |
+
}
|
161 |
+
],
|
162 |
+
"execution_count": 6,
|
163 |
+
"metadata": {
|
164 |
+
"datalore": {
|
165 |
+
"hide_input_from_viewers": true,
|
166 |
+
"hide_output_from_viewers": true,
|
167 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
168 |
+
"type": "CODE"
|
169 |
+
}
|
170 |
+
}
|
171 |
+
},
|
172 |
+
{
|
173 |
+
"cell_type": "markdown",
|
174 |
+
"source": [
|
175 |
+
"## Loading the data set"
|
176 |
+
],
|
177 |
+
"metadata": {
|
178 |
+
"datalore": {
|
179 |
+
"hide_input_from_viewers": false,
|
180 |
+
"hide_output_from_viewers": false,
|
181 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
182 |
+
"report_properties": {
|
183 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
184 |
+
},
|
185 |
+
"type": "MD"
|
186 |
+
}
|
187 |
+
}
|
188 |
+
},
|
189 |
+
{
|
190 |
+
"cell_type": "code",
|
191 |
+
"source": [
|
192 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
193 |
+
],
|
194 |
+
"outputs": [],
|
195 |
+
"execution_count": 7,
|
196 |
+
"metadata": {
|
197 |
+
"collapsed": false,
|
198 |
+
"gather": {
|
199 |
+
"logged": 1706476480469
|
200 |
+
},
|
201 |
+
"jupyter": {
|
202 |
+
"outputs_hidden": false
|
203 |
+
}
|
204 |
+
}
|
205 |
+
},
|
206 |
+
{
|
207 |
+
"cell_type": "code",
|
208 |
+
"source": [
|
209 |
+
"dataset"
|
210 |
+
],
|
211 |
+
"outputs": [
|
212 |
+
{
|
213 |
+
"output_type": "execute_result",
|
214 |
+
"execution_count": 8,
|
215 |
+
"data": {
|
216 |
+
"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
|
217 |
+
},
|
218 |
+
"metadata": {}
|
219 |
+
}
|
220 |
+
],
|
221 |
+
"execution_count": 8,
|
222 |
+
"metadata": {
|
223 |
+
"collapsed": false,
|
224 |
+
"gather": {
|
225 |
+
"logged": 1706476480629
|
226 |
+
},
|
227 |
+
"jupyter": {
|
228 |
+
"outputs_hidden": false,
|
229 |
+
"source_hidden": false
|
230 |
+
},
|
231 |
+
"nteract": {
|
232 |
+
"transient": {
|
233 |
+
"deleting": false
|
234 |
+
}
|
235 |
+
}
|
236 |
+
}
|
237 |
+
},
|
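{
"cell_type": "markdown",
"source": [
"Each record carries an `id`, the free-text report and an integer `label`; a single row can be eyeballed directly (illustrative check only):"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"dataset[\"train\"][0]"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},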
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"source": [
|
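"# Fraction of each split to keep; set to 1.0 to train on the full data set\n",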
241 |
+
"SUBSAMPLING = 0.5\n",
|
242 |
+
"\n",
|
243 |
+
"if SUBSAMPLING < 1:\n",
|
244 |
+
" _ = DatasetDict()\n",
|
245 |
+
" for each in dataset.keys():\n",
|
246 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
247 |
+
"\n",
|
248 |
+
" dataset = _"
|
249 |
+
],
|
250 |
+
"outputs": [],
|
251 |
+
"execution_count": 9,
|
252 |
+
"metadata": {
|
253 |
+
"gather": {
|
254 |
+
"logged": 1706476480826
|
255 |
+
}
|
256 |
+
}
|
257 |
+
},
|
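{
"cell_type": "markdown",
"source": [
"With `SUBSAMPLING = 0.5`, the training split shrinks from 1,270,444 to 635,222 rows, and the test and validation splits from 272,238 to 136,119 each – matching the `Map` progress bars further down."
],
"metadata": {}
},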
258 |
+
{
|
259 |
+
"cell_type": "markdown",
|
260 |
+
"source": [
|
261 |
+
"## Tokenisation and encoding"
|
262 |
+
],
|
263 |
+
"metadata": {}
|
264 |
+
},
|
265 |
+
{
|
266 |
+
"cell_type": "code",
|
267 |
+
"source": [
|
268 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
269 |
+
"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)\n",
"    ds_enc = ds.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True)\n",
"    return ds_enc"
|
270 |
+
],
|
271 |
+
"outputs": [],
|
272 |
+
"execution_count": 10,
|
273 |
+
"metadata": {
|
274 |
+
"gather": {
|
275 |
+
"logged": 1706476480944
|
276 |
+
}
|
277 |
+
}
|
278 |
+
},
|
279 |
+
{
|
280 |
+
"cell_type": "markdown",
|
281 |
+
"source": [
|
282 |
+
"## Evaluation metrics"
|
283 |
+
],
|
284 |
+
"metadata": {}
|
285 |
+
},
|
286 |
+
{
|
287 |
+
"cell_type": "code",
|
288 |
+
"source": [
|
289 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
290 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
291 |
+
"f1 = evaluate.load(\"f1\")"
|
292 |
+
],
|
293 |
+
"outputs": [],
|
294 |
+
"execution_count": 11,
|
295 |
+
"metadata": {
|
296 |
+
"gather": {
|
297 |
+
"logged": 1706476481192
|
298 |
+
}
|
299 |
+
}
|
300 |
+
},
|
301 |
+
{
|
302 |
+
"cell_type": "code",
|
303 |
+
"source": [
|
304 |
+
"def compute_metrics(eval_pred):\n",
|
305 |
+
" predictions, labels = eval_pred\n",
|
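"    # Collapse logits to predicted class ids before scoring\n",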
306 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
307 |
+
" return {\n",
|
308 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
309 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
310 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
311 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
312 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
313 |
+
"        'f1_macroaverage': f1.compute(predictions=predictions, references=labels, average='macro')[\"f1\"],\n",
"        'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
314 |
+
" }"
|
315 |
+
],
|
316 |
+
"outputs": [],
|
317 |
+
"execution_count": 12,
|
318 |
+
"metadata": {
|
319 |
+
"gather": {
|
320 |
+
"logged": 1706476481346
|
321 |
+
}
|
322 |
+
}
|
323 |
+
},
|
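{
"cell_type": "markdown",
"source": [
"A minimal sanity check of `compute_metrics` on hand-made logits – the toy arrays below are illustrative and unrelated to the VAERS data:"
],
"metadata": {}
},
{
"cell_type": "code",
"source": [
"# Toy 3-class logits and gold labels, for illustration only\n",
"toy_logits = np.array([[2.1, 0.3, 0.1], [0.2, 1.7, 0.4], [0.3, 0.2, 1.9]])\n",
"toy_labels = np.array([0, 2, 2])\n",
"compute_metrics((toy_logits, toy_labels))"
],
"outputs": [],
"execution_count": null,
"metadata": {}
},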
324 |
+
{
|
325 |
+
"cell_type": "markdown",
|
326 |
+
"source": [
|
327 |
+
"## Training"
|
328 |
+
],
|
329 |
+
"metadata": {}
|
330 |
+
},
|
331 |
+
{
|
332 |
+
"cell_type": "markdown",
|
333 |
+
"source": [
|
334 |
+
"We specify a label map manually – even though `Datasets` has a function for this, `AutoModelForSequenceClassification` requires an object with a length :("
|
335 |
+
],
|
336 |
+
"metadata": {}
|
337 |
+
},
|
338 |
+
{
|
339 |
+
"cell_type": "code",
|
340 |
+
"source": [
|
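"# id -> label name mapping, taken from the dataset's ClassLabel feature\n",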
341 |
+
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
342 |
+
],
|
343 |
+
"outputs": [],
|
344 |
+
"execution_count": 13,
|
345 |
+
"metadata": {
|
346 |
+
"gather": {
|
347 |
+
"logged": 1706476481593
|
348 |
+
}
|
349 |
+
}
|
350 |
+
},
|
351 |
+
{
|
352 |
+
"cell_type": "code",
|
353 |
+
"source": [
|
354 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
355 |
+
"\n",
|
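"# After mapping, keep only the label plus the tokenizer outputs (input_ids, attention_mask)\n",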
356 |
+
"cols = dataset[\"train\"].column_names\n",
|
357 |
+
"cols.remove(\"label\")\n",
|
358 |
+
"ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n",
|
359 |
+
"\n",
|
360 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
361 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
362 |
+
" id2label=label_map, \n",
|
363 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
364 |
+
"\n",
|
365 |
+
"args = TrainingArguments(\n",
|
366 |
+
" output_dir=\"vaers\",\n",
|
367 |
+
" evaluation_strategy=\"epoch\",\n",
|
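"    # evaluation and save strategies must match for load_best_model_at_end to work\n",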
368 |
+
" save_strategy=\"epoch\",\n",
|
369 |
+
" learning_rate=2e-5,\n",
|
370 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
371 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
372 |
+
" num_train_epochs=EPOCHS,\n",
|
373 |
+
" weight_decay=.01,\n",
|
374 |
+
" logging_steps=1,\n",
|
375 |
+
" load_best_model_at_end=True,\n",
|
376 |
+
"    run_name=\"daedra-training\",\n",
|
377 |
+
" report_to=[\"wandb\"])\n",
|
378 |
+
"\n",
|
379 |
+
"trainer = Trainer(\n",
|
380 |
+
" model=model,\n",
|
381 |
+
" args=args,\n",
|
382 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
383 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
384 |
+
" tokenizer=tokenizer,\n",
|
385 |
+
" compute_metrics=compute_metrics)"
|
386 |
+
],
|
387 |
+
"outputs": [
|
388 |
+
{
|
389 |
+
"output_type": "stream",
|
390 |
+
"name": "stderr",
|
391 |
+
"text": "Map: 100%|██████████| 635222/635222 [04:25<00:00, 2395.47 examples/s]\nMap: 100%|██████████| 136119/136119 [00:56<00:00, 2405.75 examples/s]\nMap: 100%|██████████| 136119/136119 [00:56<00:00, 2422.27 examples/s]\nSome weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
392 |
+
}
|
393 |
+
],
|
394 |
+
"execution_count": 14,
|
395 |
+
"metadata": {
|
396 |
+
"gather": {
|
397 |
+
"logged": 1706476861739
|
398 |
+
}
|
399 |
+
}
|
400 |
+
},
|
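{
"cell_type": "markdown",
"source": [
"The `newly initialized` warning above is expected: `distilbert-base-uncased` ships without a sequence classification head, so the `pre_classifier` and `classifier` weights start from random initialisation and are learned during fine-tuning."
],
"metadata": {}
},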
401 |
+
{
|
402 |
+
"cell_type": "code",
|
403 |
+
"source": [
|
404 |
+
"if SUBSAMPLING != 1.0:\n",
|
405 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
406 |
+
"else:\n",
|
407 |
+
"    wandb_tag: List[str] = [\"full_sample\"]\n",
|
408 |
+
"\n",
|
409 |
+
"wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
410 |
+
"wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
411 |
+
" \n",
|
412 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
|
413 |
+
],
|
414 |
+
"outputs": [
|
415 |
+
{
|
416 |
+
"output_type": "stream",
|
417 |
+
"name": "stderr",
|
418 |
+
"text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
419 |
+
},
|
420 |
+
{
|
421 |
+
"output_type": "display_data",
|
422 |
+
"data": {
|
423 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
424 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
425 |
+
},
|
426 |
+
"metadata": {}
|
427 |
+
},
|
428 |
+
{
|
429 |
+
"output_type": "display_data",
|
430 |
+
"data": {
|
431 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
432 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_212103-403j5ij5</code>"
|
433 |
+
},
|
434 |
+
"metadata": {}
|
435 |
+
},
|
436 |
+
{
|
437 |
+
"output_type": "display_data",
|
438 |
+
"data": {
|
439 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
440 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
441 |
+
},
|
442 |
+
"metadata": {}
|
443 |
+
},
|
444 |
+
{
|
445 |
+
"output_type": "display_data",
|
446 |
+
"data": {
|
447 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
448 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
449 |
+
},
|
450 |
+
"metadata": {}
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"output_type": "display_data",
|
454 |
+
"data": {
|
455 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
456 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5</a>"
|
457 |
+
},
|
458 |
+
"metadata": {}
|
459 |
+
},
|
460 |
+
{
|
461 |
+
"output_type": "display_data",
|
462 |
+
"data": {
|
463 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
464 |
+
"text/html": "Finishing last run (ID:403j5ij5) before initializing another..."
|
465 |
+
},
|
466 |
+
"metadata": {}
|
467 |
+
},
|
468 |
+
{
|
469 |
+
"output_type": "display_data",
|
470 |
+
"data": {
|
471 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
472 |
+
"text/html": " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/403j5ij5</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
473 |
+
},
|
474 |
+
"metadata": {}
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"output_type": "display_data",
|
478 |
+
"data": {
|
479 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
480 |
+
"text/html": "Find logs at: <code>./wandb/run-20240128_212103-403j5ij5/logs</code>"
|
481 |
+
},
|
482 |
+
"metadata": {}
|
483 |
+
},
|
484 |
+
{
|
485 |
+
"output_type": "display_data",
|
486 |
+
"data": {
|
487 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
488 |
+
"text/html": "Successfully finished last run (ID:403j5ij5). Initializing new run:<br/>"
|
489 |
+
},
|
490 |
+
"metadata": {}
|
491 |
+
},
|
492 |
+
{
|
493 |
+
"output_type": "display_data",
|
494 |
+
"data": {
|
495 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
496 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
497 |
+
},
|
498 |
+
"metadata": {}
|
499 |
+
},
|
500 |
+
{
|
501 |
+
"output_type": "display_data",
|
502 |
+
"data": {
|
503 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
504 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_212105-q65k78ea</code>"
|
505 |
+
},
|
506 |
+
"metadata": {}
|
507 |
+
},
|
508 |
+
{
|
509 |
+
"output_type": "display_data",
|
510 |
+
"data": {
|
511 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
512 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
513 |
+
},
|
514 |
+
"metadata": {}
|
515 |
+
},
|
516 |
+
{
|
517 |
+
"output_type": "display_data",
|
518 |
+
"data": {
|
519 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
520 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
521 |
+
},
|
522 |
+
"metadata": {}
|
523 |
+
},
|
524 |
+
{
|
525 |
+
"output_type": "display_data",
|
526 |
+
"data": {
|
527 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
528 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea</a>"
|
529 |
+
},
|
530 |
+
"metadata": {}
|
531 |
+
},
|
532 |
+
{
|
533 |
+
"output_type": "execute_result",
|
534 |
+
"execution_count": 15,
|
535 |
+
"data": {
|
536 |
+
"text/html": "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/q65k78ea?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>",
|
537 |
+
"text/plain": "<wandb.sdk.wandb_run.Run at 0x7fea31898d90>"
|
538 |
+
},
|
539 |
+
"metadata": {}
|
540 |
+
}
|
541 |
+
],
|
542 |
+
"execution_count": 15,
|
543 |
+
"metadata": {
|
544 |
+
"gather": {
|
545 |
+
"logged": 1706476872191
|
546 |
+
}
|
547 |
+
}
|
548 |
+
},
|
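{
"cell_type": "markdown",
"source": [
"The finish-and-restart above is a side effect of `magic=True`: wandb magic had already opened a run, so the explicit `wandb.init()` closes it (`403j5ij5`) and starts a fresh one (`q65k78ea`), which is the run the training below reports to."
],
"metadata": {}
},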
549 |
+
{
|
550 |
+
"cell_type": "code",
|
551 |
+
"source": [
|
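"# tracker: codecarbon emissions tracker, assumed to be initialised earlier in the notebook\n",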
552 |
+
"tracker.start()\n",
|
553 |
+
"trainer.train()\n",
|
554 |
+
"tracker.stop()\n"
|
555 |
+
],
|
556 |
+
"outputs": [
|
557 |
+
{
|
558 |
+
"output_type": "stream",
|
559 |
+
"name": "stderr",
|
560 |
+
"text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"output_type": "display_data",
|
564 |
+
"data": {
|
565 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
566 |
+
"text/html": "\n <div>\n \n <progress value='2907' max='14889' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 2907/14889 31:21 < 2:09:21, 1.54 it/s, Epoch 0.59/3]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
|
567 |
+
},
|
568 |
+
"metadata": {}
|
569 |
+
},
|
570 |
+
{
|
571 |
+
"output_type": "stream",
|
572 |
+
"name": "stderr",
|
573 |
+
"text": "[codecarbon INFO @ 21:21:26] Energy consumed for RAM : 0.000690 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:21:26] Energy consumed for all GPUs : 0.001498 kWh. Total GPU Power : 359.01546838188807 W\n[codecarbon INFO @ 21:21:26] Energy consumed for all CPUs : 0.000177 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:21:26] 0.002365 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:21:41] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:21:41] Energy consumed for all GPUs : 0.004065 kWh. Total GPU Power : 616.9146793770267 W\n[codecarbon INFO @ 21:21:41] Energy consumed for all CPUs : 0.000354 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:21:41] 0.005797 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:21:56] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:21:56] Energy consumed for all GPUs : 0.006654 kWh. Total GPU Power : 621.6877665436252 W\n[codecarbon INFO @ 21:21:56] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:21:56] 0.009251 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:11] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:11] Energy consumed for all GPUs : 0.009260 kWh. Total GPU Power : 626.1437572465749 W\n[codecarbon INFO @ 21:22:11] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:11] 0.012723 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:26] Energy consumed for RAM : 0.003443 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:26] Energy consumed for all GPUs : 0.011865 kWh. Total GPU Power : 625.3693802936192 W\n[codecarbon INFO @ 21:22:26] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:26] 0.016193 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:41] Energy consumed for RAM : 0.004131 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:41] Energy consumed for all GPUs : 0.014488 kWh. Total GPU Power : 630.2419235639226 W\n[codecarbon INFO @ 21:22:41] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:41] 0.019682 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:22:56] Energy consumed for RAM : 0.004819 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:22:56] Energy consumed for all GPUs : 0.017135 kWh. Total GPU Power : 635.8556506868297 W\n[codecarbon INFO @ 21:22:56] Energy consumed for all CPUs : 0.001240 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:22:56] 0.023194 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:11] Energy consumed for RAM : 0.005507 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:11] Energy consumed for all GPUs : 0.019738 kWh. Total GPU Power : 625.0758518089303 W\n[codecarbon INFO @ 21:23:11] Energy consumed for all CPUs : 0.001417 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:11] 0.026662 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:26] Energy consumed for RAM : 0.006195 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:26] Energy consumed for all GPUs : 0.022385 kWh. Total GPU Power : 636.0572579593729 W\n[codecarbon INFO @ 21:23:26] Energy consumed for all CPUs : 0.001594 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:26] 0.030175 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:41] Energy consumed for RAM : 0.006883 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:41] Energy consumed for all GPUs : 0.025031 kWh. Total GPU Power : 635.4132918961806 W\n[codecarbon INFO @ 21:23:41] Energy consumed for all CPUs : 0.001771 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:41] 0.033685 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:23:56] Energy consumed for RAM : 0.007572 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:23:56] Energy consumed for all GPUs : 0.027661 kWh. Total GPU Power : 631.8222916777424 W\n[codecarbon INFO @ 21:23:56] Energy consumed for all CPUs : 0.001948 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:23:56] 0.037180 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:11] Energy consumed for RAM : 0.008260 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:11] Energy consumed for all GPUs : 0.030315 kWh. Total GPU Power : 637.844758085687 W\n[codecarbon INFO @ 21:24:11] Energy consumed for all CPUs : 0.002125 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:11] 0.040701 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:26] Energy consumed for RAM : 0.008948 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:26] Energy consumed for all GPUs : 0.032970 kWh. Total GPU Power : 637.7063667069607 W\n[codecarbon INFO @ 21:24:26] Energy consumed for all CPUs : 0.002302 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:26] 0.044220 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:41] Energy consumed for RAM : 0.009636 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:41] Energy consumed for all GPUs : 0.035629 kWh. Total GPU Power : 638.7595521159491 W\n[codecarbon INFO @ 21:24:41] Energy consumed for all CPUs : 0.002479 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:41] 0.047744 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:24:56] Energy consumed for RAM : 0.010324 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:24:56] Energy consumed for all GPUs : 0.038234 kWh. Total GPU Power : 626.0118880295652 W\n[codecarbon INFO @ 21:24:56] Energy consumed for all CPUs : 0.002657 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:24:56] 0.051214 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:11] Energy consumed for RAM : 0.011012 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:11] Energy consumed for all GPUs : 0.040892 kWh. Total GPU Power : 638.4170631771941 W\n[codecarbon INFO @ 21:25:11] Energy consumed for all CPUs : 0.002834 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:11] 0.054738 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:26] Energy consumed for RAM : 0.011700 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:26] Energy consumed for all GPUs : 0.043524 kWh. Total GPU Power : 632.34394576946 W\n[codecarbon INFO @ 21:25:26] Energy consumed for all CPUs : 0.003011 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:26] 0.058235 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:41] Energy consumed for RAM : 0.012388 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:41] Energy consumed for all GPUs : 0.046182 kWh. 
Total GPU Power : 638.4662389546352 W\n[codecarbon INFO @ 21:25:41] Energy consumed for all CPUs : 0.003188 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:41] 0.061758 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:25:56] Energy consumed for RAM : 0.013076 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:25:56] Energy consumed for all GPUs : 0.048838 kWh. Total GPU Power : 638.0871021853263 W\n[codecarbon INFO @ 21:25:56] Energy consumed for all CPUs : 0.003365 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:25:56] 0.065279 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:11] Energy consumed for RAM : 0.013765 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:11] Energy consumed for all GPUs : 0.051499 kWh. Total GPU Power : 639.0983849678707 W\n[codecarbon INFO @ 21:26:11] Energy consumed for all CPUs : 0.003542 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:11] 0.068806 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:26] Energy consumed for RAM : 0.014453 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:26] Energy consumed for all GPUs : 0.054132 kWh. Total GPU Power : 632.549674567773 W\n[codecarbon INFO @ 21:26:26] Energy consumed for all CPUs : 0.003719 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:26] 0.072304 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:41] Energy consumed for RAM : 0.015141 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:41] Energy consumed for all GPUs : 0.056768 kWh. Total GPU Power : 633.3096652159345 W\n[codecarbon INFO @ 21:26:41] Energy consumed for all CPUs : 0.003896 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:41] 0.075804 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:26:56] Energy consumed for RAM : 0.015829 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:26:56] Energy consumed for all GPUs : 0.059404 kWh. Total GPU Power : 633.339846576491 W\n[codecarbon INFO @ 21:26:56] Energy consumed for all CPUs : 0.004073 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:26:56] 0.079306 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:11] Energy consumed for RAM : 0.016517 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:11] Energy consumed for all GPUs : 0.062068 kWh. Total GPU Power : 639.9100492849137 W\n[codecarbon INFO @ 21:27:11] Energy consumed for all CPUs : 0.004250 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:11] 0.082835 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:26] Energy consumed for RAM : 0.017205 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:26] Energy consumed for all GPUs : 0.064726 kWh. Total GPU Power : 638.6437092393893 W\n[codecarbon INFO @ 21:27:26] Energy consumed for all CPUs : 0.004427 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:26] 0.086359 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:41] Energy consumed for RAM : 0.017893 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:41] Energy consumed for all GPUs : 0.067388 kWh. Total GPU Power : 639.3487979354586 W\n[codecarbon INFO @ 21:27:41] Energy consumed for all CPUs : 0.004604 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:41] 0.089885 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:27:56] Energy consumed for RAM : 0.018581 kWh. 
RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:27:56] Energy consumed for all GPUs : 0.070026 kWh. Total GPU Power : 633.6884387646057 W\n[codecarbon INFO @ 21:27:56] Energy consumed for all CPUs : 0.004781 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:27:56] 0.093389 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:11] Energy consumed for RAM : 0.019269 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:11] Energy consumed for all GPUs : 0.072687 kWh. Total GPU Power : 639.4422525221754 W\n[codecarbon INFO @ 21:28:11] Energy consumed for all CPUs : 0.004958 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:11] 0.096915 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:26] Energy consumed for RAM : 0.019958 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:26] Energy consumed for all GPUs : 0.075296 kWh. Total GPU Power : 626.9464464111006 W\n[codecarbon INFO @ 21:28:26] Energy consumed for all CPUs : 0.005135 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:26] 0.100390 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:41] Energy consumed for RAM : 0.020646 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:41] Energy consumed for all GPUs : 0.077962 kWh. Total GPU Power : 640.3962575270206 W\n[codecarbon INFO @ 21:28:41] Energy consumed for all CPUs : 0.005313 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:41] 0.103921 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:28:56] Energy consumed for RAM : 0.021334 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:28:56] Energy consumed for all GPUs : 0.080619 kWh. Total GPU Power : 638.3087387539953 W\n[codecarbon INFO @ 21:28:56] Energy consumed for all CPUs : 0.005490 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:28:56] 0.107443 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:11] Energy consumed for RAM : 0.022022 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:11] Energy consumed for all GPUs : 0.083270 kWh. Total GPU Power : 636.8708359764104 W\n[codecarbon INFO @ 21:29:11] Energy consumed for all CPUs : 0.005667 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:11] 0.110959 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:26] Energy consumed for RAM : 0.022710 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:26] Energy consumed for all GPUs : 0.085893 kWh. Total GPU Power : 630.0796388169725 W\n[codecarbon INFO @ 21:29:26] Energy consumed for all CPUs : 0.005844 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:26] 0.114447 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:41] Energy consumed for RAM : 0.023398 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:41] Energy consumed for all GPUs : 0.088548 kWh. Total GPU Power : 637.7758378447022 W\n[codecarbon INFO @ 21:29:41] Energy consumed for all CPUs : 0.006021 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:41] 0.117968 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:29:56] Energy consumed for RAM : 0.024087 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:29:56] Energy consumed for all GPUs : 0.091193 kWh. Total GPU Power : 634.8550146720521 W\n[codecarbon INFO @ 21:29:56] Energy consumed for all CPUs : 0.006198 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:29:56] 0.121478 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:11] Energy consumed for RAM : 0.024775 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:11] Energy consumed for all GPUs : 0.093817 kWh. Total GPU Power : 630.5186457226341 W\n[codecarbon INFO @ 21:30:11] Energy consumed for all CPUs : 0.006375 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:11] 0.124967 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:26] Energy consumed for RAM : 0.025463 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:26] Energy consumed for all GPUs : 0.096471 kWh. Total GPU Power : 637.5849420686613 W\n[codecarbon INFO @ 21:30:26] Energy consumed for all CPUs : 0.006552 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:26] 0.128486 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:41] Energy consumed for RAM : 0.026151 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:41] Energy consumed for all GPUs : 0.099109 kWh. Total GPU Power : 633.6189362439791 W\n[codecarbon INFO @ 21:30:41] Energy consumed for all CPUs : 0.006729 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:41] 0.131990 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:30:56] Energy consumed for RAM : 0.026839 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:30:56] Energy consumed for all GPUs : 0.101776 kWh. Total GPU Power : 640.6257944471723 W\n[codecarbon INFO @ 21:30:56] Energy consumed for all CPUs : 0.006906 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:30:56] 0.135522 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:11] Energy consumed for RAM : 0.027528 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:11] Energy consumed for all GPUs : 0.104439 kWh. Total GPU Power : 639.6643020904513 W\n[codecarbon INFO @ 21:31:11] Energy consumed for all CPUs : 0.007083 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:11] 0.139050 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:26] Energy consumed for RAM : 0.028216 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:26] Energy consumed for all GPUs : 0.107104 kWh. Total GPU Power : 640.0939172348444 W\n[codecarbon INFO @ 21:31:26] Energy consumed for all CPUs : 0.007260 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:26] 0.142580 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:41] Energy consumed for RAM : 0.028904 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:41] Energy consumed for all GPUs : 0.109747 kWh. Total GPU Power : 634.9935430223439 W\n[codecarbon INFO @ 21:31:41] Energy consumed for all CPUs : 0.007438 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:41] 0.146088 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:31:56] Energy consumed for RAM : 0.029592 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:31:56] Energy consumed for all GPUs : 0.112378 kWh. Total GPU Power : 632.1091666419065 W\n[codecarbon INFO @ 21:31:56] Energy consumed for all CPUs : 0.007615 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:31:56] 0.149585 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:11] Energy consumed for RAM : 0.030281 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:11] Energy consumed for all GPUs : 0.115023 kWh. 
Total GPU Power : 633.8442942682154 W\n[codecarbon INFO @ 21:32:11] Energy consumed for all CPUs : 0.007792 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:11] 0.153096 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:26] Energy consumed for RAM : 0.030968 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:26] Energy consumed for all GPUs : 0.117663 kWh. Total GPU Power : 635.0998743598051 W\n[codecarbon INFO @ 21:32:26] Energy consumed for all CPUs : 0.007969 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:26] 0.156599 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:41] Energy consumed for RAM : 0.031656 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:41] Energy consumed for all GPUs : 0.120332 kWh. Total GPU Power : 641.1040333084177 W\n[codecarbon INFO @ 21:32:41] Energy consumed for all CPUs : 0.008146 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:41] 0.160134 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:32:56] Energy consumed for RAM : 0.032344 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:32:56] Energy consumed for all GPUs : 0.122995 kWh. Total GPU Power : 639.7462391288783 W\n[codecarbon INFO @ 21:32:56] Energy consumed for all CPUs : 0.008323 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:32:56] 0.163662 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:11] Energy consumed for RAM : 0.033033 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:11] Energy consumed for all GPUs : 0.125648 kWh. Total GPU Power : 637.3370633888644 W\n[codecarbon INFO @ 21:33:11] Energy consumed for all CPUs : 0.008500 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:11] 0.167180 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:26] Energy consumed for RAM : 0.033721 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:26] Energy consumed for all GPUs : 0.128260 kWh. Total GPU Power : 627.497349520734 W\n[codecarbon INFO @ 21:33:26] Energy consumed for all CPUs : 0.008677 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:26] 0.170658 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:41] Energy consumed for RAM : 0.034409 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:41] Energy consumed for all GPUs : 0.130922 kWh. Total GPU Power : 639.378459827986 W\n[codecarbon INFO @ 21:33:41] Energy consumed for all CPUs : 0.008854 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:41] 0.174185 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:33:56] Energy consumed for RAM : 0.035097 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:33:56] Energy consumed for all GPUs : 0.133563 kWh. Total GPU Power : 634.2779963263187 W\n[codecarbon INFO @ 21:33:56] Energy consumed for all CPUs : 0.009031 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:33:56] 0.177692 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:11] Energy consumed for RAM : 0.035785 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:11] Energy consumed for all GPUs : 0.136211 kWh. Total GPU Power : 636.088462236655 W\n[codecarbon INFO @ 21:34:11] Energy consumed for all CPUs : 0.009208 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:11] 0.181205 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:26] Energy consumed for RAM : 0.036474 kWh. 
RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:26] Energy consumed for all GPUs : 0.138875 kWh. Total GPU Power : 639.8420566736949 W\n[codecarbon INFO @ 21:34:26] Energy consumed for all CPUs : 0.009385 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:26] 0.184734 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:41] Energy consumed for RAM : 0.037162 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:41] Energy consumed for all GPUs : 0.141537 kWh. Total GPU Power : 639.4628459940732 W\n[codecarbon INFO @ 21:34:41] Energy consumed for all CPUs : 0.009563 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:41] 0.188261 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:34:56] Energy consumed for RAM : 0.037850 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:34:56] Energy consumed for all GPUs : 0.144193 kWh. Total GPU Power : 638.284138549091 W\n[codecarbon INFO @ 21:34:56] Energy consumed for all CPUs : 0.009740 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:34:56] 0.191782 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:11] Energy consumed for RAM : 0.038538 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:11] Energy consumed for all GPUs : 0.146807 kWh. Total GPU Power : 627.9721129851367 W\n[codecarbon INFO @ 21:35:11] Energy consumed for all CPUs : 0.009917 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:11] 0.195262 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:26] Energy consumed for RAM : 0.039226 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:26] Energy consumed for all GPUs : 0.149465 kWh. Total GPU Power : 638.5284782703005 W\n[codecarbon INFO @ 21:35:26] Energy consumed for all CPUs : 0.010094 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:26] 0.198785 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:41] Energy consumed for RAM : 0.039914 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:41] Energy consumed for all GPUs : 0.152101 kWh. Total GPU Power : 633.1180716439897 W\n[codecarbon INFO @ 21:35:41] Energy consumed for all CPUs : 0.010271 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:41] 0.202286 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:35:56] Energy consumed for RAM : 0.040602 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:35:56] Energy consumed for all GPUs : 0.154764 kWh. Total GPU Power : 640.0670545574203 W\n[codecarbon INFO @ 21:35:56] Energy consumed for all CPUs : 0.010448 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:35:56] 0.205814 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:36:11] Energy consumed for RAM : 0.041290 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:36:11] Energy consumed for all GPUs : 0.157432 kWh. Total GPU Power : 640.6751187111053 W\n[codecarbon INFO @ 21:36:11] Energy consumed for all CPUs : 0.010625 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:36:11] 0.209347 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:36:26] Energy consumed for RAM : 0.041979 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:36:26] Energy consumed for all GPUs : 0.160090 kWh. Total GPU Power : 638.5720494734854 W\n[codecarbon INFO @ 21:36:26] Energy consumed for all CPUs : 0.010802 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:36:26] 0.212871 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:36:41] Energy consumed for RAM : 0.042667 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:36:41] Energy consumed for all GPUs : 0.162737 kWh. Total GPU Power : 635.8084991485674 W\n[codecarbon INFO @ 21:36:41] Energy consumed for all CPUs : 0.010979 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:36:41] 0.216383 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:36:56] Energy consumed for RAM : 0.043355 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:36:56] Energy consumed for all GPUs : 0.165376 kWh. Total GPU Power : 634.1987011824029 W\n[codecarbon INFO @ 21:36:56] Energy consumed for all CPUs : 0.011156 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:36:56] 0.219886 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:37:11] Energy consumed for RAM : 0.044043 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:37:11] Energy consumed for all GPUs : 0.168015 kWh. Total GPU Power : 633.9887371766706 W\n[codecarbon INFO @ 21:37:11] Energy consumed for all CPUs : 0.011333 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:37:11] 0.223391 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:37:26] Energy consumed for RAM : 0.044731 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:37:26] Energy consumed for all GPUs : 0.170672 kWh. Total GPU Power : 638.6399487093975 W\n[codecarbon INFO @ 21:37:26] Energy consumed for all CPUs : 0.011510 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:37:26] 0.226913 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:37:41] Energy consumed for RAM : 0.045419 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:37:41] Energy consumed for all GPUs : 0.173330 kWh. Total GPU Power : 638.3717543169629 W\n[codecarbon INFO @ 21:37:41] Energy consumed for all CPUs : 0.011687 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:37:41] 0.230436 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:37:56] Energy consumed for RAM : 0.046107 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:37:56] Energy consumed for all GPUs : 0.175996 kWh. Total GPU Power : 640.4666251525215 W\n[codecarbon INFO @ 21:37:56] Energy consumed for all CPUs : 0.011864 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:37:56] 0.233967 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:38:11] Energy consumed for RAM : 0.046795 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:38:11] Energy consumed for all GPUs : 0.178655 kWh. Total GPU Power : 638.5808546734917 W\n[codecarbon INFO @ 21:38:11] Energy consumed for all CPUs : 0.012042 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:38:11] 0.237491 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:38:26] Energy consumed for RAM : 0.047483 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:38:26] Energy consumed for all GPUs : 0.181299 kWh. Total GPU Power : 635.2413118760992 W\n[codecarbon INFO @ 21:38:26] Energy consumed for all CPUs : 0.012219 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:38:26] 0.241001 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:38:41] Energy consumed for RAM : 0.048172 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:38:41] Energy consumed for all GPUs : 0.183936 kWh. 
Total GPU Power : 633.4649546198842 W\n[codecarbon INFO @ 21:38:41] Energy consumed for all CPUs : 0.012396 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:38:41] 0.244503 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:38:56] Energy consumed for RAM : 0.048860 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:38:56] Energy consumed for all GPUs : 0.186584 kWh. Total GPU Power : 636.131524414156 W\n[codecarbon INFO @ 21:38:56] Energy consumed for all CPUs : 0.012573 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:38:56] 0.248017 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:39:11] Energy consumed for RAM : 0.049548 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:39:11] Energy consumed for all GPUs : 0.189253 kWh. Total GPU Power : 641.2634282249857 W\n[codecarbon INFO @ 21:39:11] Energy consumed for all CPUs : 0.012750 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:39:11] 0.251551 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:39:26] Energy consumed for RAM : 0.050236 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:39:26] Energy consumed for all GPUs : 0.191926 kWh. Total GPU Power : 642.0441434380534 W\n[codecarbon INFO @ 21:39:26] Energy consumed for all CPUs : 0.012927 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:39:26] 0.255089 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:39:41] Energy consumed for RAM : 0.050924 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:39:41] Energy consumed for all GPUs : 0.194582 kWh. Total GPU Power : 637.9781331087586 W\n[codecarbon INFO @ 21:39:41] Energy consumed for all CPUs : 0.013104 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:39:41] 0.258610 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:39:56] Energy consumed for RAM : 0.051612 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:39:56] Energy consumed for all GPUs : 0.197230 kWh. Total GPU Power : 636.2697706727595 W\n[codecarbon INFO @ 21:39:56] Energy consumed for all CPUs : 0.013281 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:39:56] 0.262124 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:40:11] Energy consumed for RAM : 0.052300 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:40:11] Energy consumed for all GPUs : 0.199895 kWh. Total GPU Power : 640.3428775768339 W\n[codecarbon INFO @ 21:40:11] Energy consumed for all CPUs : 0.013458 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:40:11] 0.265654 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:40:26] Energy consumed for RAM : 0.052988 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:40:26] Energy consumed for all GPUs : 0.202515 kWh. Total GPU Power : 629.2766685535789 W\n[codecarbon INFO @ 21:40:26] Energy consumed for all CPUs : 0.013635 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:40:26] 0.269139 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:40:41] Energy consumed for RAM : 0.053677 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:40:41] Energy consumed for all GPUs : 0.205189 kWh. Total GPU Power : 642.1835694357527 W\n[codecarbon INFO @ 21:40:41] Energy consumed for all CPUs : 0.013812 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:40:41] 0.272678 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:40:56] Energy consumed for RAM : 0.054365 kWh. 
RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:40:56] Energy consumed for all GPUs : 0.207849 kWh. Total GPU Power : 639.0512123347582 W\n[codecarbon INFO @ 21:40:56] Energy consumed for all CPUs : 0.013989 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:40:56] 0.276203 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:41:11] Energy consumed for RAM : 0.055053 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:41:11] Energy consumed for all GPUs : 0.210510 kWh. Total GPU Power : 639.2114833450486 W\n[codecarbon INFO @ 21:41:11] Energy consumed for all CPUs : 0.014166 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:41:11] 0.279729 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:41:26] Energy consumed for RAM : 0.055741 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:41:26] Energy consumed for all GPUs : 0.213150 kWh. Total GPU Power : 634.2127304940207 W\n[codecarbon INFO @ 21:41:26] Energy consumed for all CPUs : 0.014343 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:41:26] 0.283234 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:41:41] Energy consumed for RAM : 0.056429 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:41:41] Energy consumed for all GPUs : 0.215819 kWh. Total GPU Power : 641.2007745785662 W\n[codecarbon INFO @ 21:41:41] Energy consumed for all CPUs : 0.014521 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:41:41] 0.286769 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:41:56] Energy consumed for RAM : 0.057117 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:41:56] Energy consumed for all GPUs : 0.218486 kWh. Total GPU Power : 640.7227371900796 W\n[codecarbon INFO @ 21:41:56] Energy consumed for all CPUs : 0.014698 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:41:56] 0.290301 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:42:11] Energy consumed for RAM : 0.057806 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:42:11] Energy consumed for all GPUs : 0.221103 kWh. Total GPU Power : 628.6339716630874 W\n[codecarbon INFO @ 21:42:11] Energy consumed for all CPUs : 0.014875 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:42:11] 0.293783 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:42:26] Energy consumed for RAM : 0.058494 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:42:26] Energy consumed for all GPUs : 0.223767 kWh. Total GPU Power : 640.0746364068535 W\n[codecarbon INFO @ 21:42:26] Energy consumed for all CPUs : 0.015052 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:42:26] 0.297313 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:42:41] Energy consumed for RAM : 0.059182 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:42:41] Energy consumed for all GPUs : 0.226424 kWh. Total GPU Power : 638.2080824809826 W\n[codecarbon INFO @ 21:42:41] Energy consumed for all CPUs : 0.015229 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:42:41] 0.300835 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:42:56] Energy consumed for RAM : 0.059870 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:42:56] Energy consumed for all GPUs : 0.229064 kWh. Total GPU Power : 634.2533218929578 W\n[codecarbon INFO @ 21:42:56] Energy consumed for all CPUs : 0.015406 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:42:56] 0.304340 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:43:11] Energy consumed for RAM : 0.060558 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:43:11] Energy consumed for all GPUs : 0.231699 kWh. Total GPU Power : 632.8743496381484 W\n[codecarbon INFO @ 21:43:11] Energy consumed for all CPUs : 0.015583 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:43:11] 0.307840 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:43:26] Energy consumed for RAM : 0.061246 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:43:26] Energy consumed for all GPUs : 0.234369 kWh. Total GPU Power : 641.6283834036132 W\n[codecarbon INFO @ 21:43:26] Energy consumed for all CPUs : 0.015760 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:43:26] 0.311375 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:43:41] Energy consumed for RAM : 0.061934 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:43:41] Energy consumed for all GPUs : 0.237043 kWh. Total GPU Power : 642.3040649089497 W\n[codecarbon INFO @ 21:43:41] Energy consumed for all CPUs : 0.015937 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:43:41] 0.314915 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:43:56] Energy consumed for RAM : 0.062622 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:43:56] Energy consumed for all GPUs : 0.239655 kWh. Total GPU Power : 627.6126294000912 W\n[codecarbon INFO @ 21:43:56] Energy consumed for all CPUs : 0.016114 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:43:56] 0.318392 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:44:11] Energy consumed for RAM : 0.063311 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:44:11] Energy consumed for all GPUs : 0.242321 kWh. Total GPU Power : 640.4353576994267 W\n[codecarbon INFO @ 21:44:11] Energy consumed for all CPUs : 0.016291 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:44:11] 0.321923 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:44:26] Energy consumed for RAM : 0.063999 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:44:26] Energy consumed for all GPUs : 0.244988 kWh. Total GPU Power : 640.5277401792778 W\n[codecarbon INFO @ 21:44:26] Energy consumed for all CPUs : 0.016468 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:44:26] 0.325455 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:44:41] Energy consumed for RAM : 0.064687 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:44:41] Energy consumed for all GPUs : 0.247629 kWh. Total GPU Power : 634.3819968519699 W\n[codecarbon INFO @ 21:44:41] Energy consumed for all CPUs : 0.016645 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:44:41] 0.328961 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:44:56] Energy consumed for RAM : 0.065375 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:44:56] Energy consumed for all GPUs : 0.250281 kWh. Total GPU Power : 637.265582807383 W\n[codecarbon INFO @ 21:44:56] Energy consumed for all CPUs : 0.016823 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:44:56] 0.332479 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:45:11] Energy consumed for RAM : 0.066063 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:45:11] Energy consumed for all GPUs : 0.252945 kWh. 
Total GPU Power : 639.7317988572786 W\n[codecarbon INFO @ 21:45:11] Energy consumed for all CPUs : 0.017000 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:45:11] 0.336008 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:45:26] Energy consumed for RAM : 0.066751 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:45:26] Energy consumed for all GPUs : 0.255612 kWh. Total GPU Power : 640.7293994008817 W\n[codecarbon INFO @ 21:45:26] Energy consumed for all CPUs : 0.017177 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:45:26] 0.339540 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:45:41] Energy consumed for RAM : 0.067439 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:45:41] Energy consumed for all GPUs : 0.258225 kWh. Total GPU Power : 627.831662994067 W\n[codecarbon INFO @ 21:45:41] Energy consumed for all CPUs : 0.017354 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:45:41] 0.343018 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:45:56] Energy consumed for RAM : 0.068128 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:45:56] Energy consumed for all GPUs : 0.260886 kWh. Total GPU Power : 639.0834373322126 W\n[codecarbon INFO @ 21:45:56] Energy consumed for all CPUs : 0.017531 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:45:56] 0.346544 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:46:11] Energy consumed for RAM : 0.068816 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:46:11] Energy consumed for all GPUs : 0.263551 kWh. Total GPU Power : 640.21811942804 W\n[codecarbon INFO @ 21:46:11] Energy consumed for all CPUs : 0.017708 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:46:11] 0.350075 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:46:26] Energy consumed for RAM : 0.069504 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:46:26] Energy consumed for all GPUs : 0.266199 kWh. Total GPU Power : 635.36554464275 W\n[codecarbon INFO @ 21:46:26] Energy consumed for all CPUs : 0.017885 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:46:26] 0.353588 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:46:41] Energy consumed for RAM : 0.070192 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:46:41] Energy consumed for all GPUs : 0.268846 kWh. Total GPU Power : 636.3615954525276 W\n[codecarbon INFO @ 21:46:41] Energy consumed for all CPUs : 0.018062 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:46:41] 0.357099 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:46:56] Energy consumed for RAM : 0.070880 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:46:56] Energy consumed for all GPUs : 0.271505 kWh. Total GPU Power : 638.7791497333527 W\n[codecarbon INFO @ 21:46:56] Energy consumed for all CPUs : 0.018239 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:46:56] 0.360624 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:47:11] Energy consumed for RAM : 0.071568 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:47:11] Energy consumed for all GPUs : 0.274180 kWh. Total GPU Power : 642.4466497637459 W\n[codecarbon INFO @ 21:47:11] Energy consumed for all CPUs : 0.018416 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:47:11] 0.364164 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:47:26] Energy consumed for RAM : 0.072256 kWh. 
RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:47:26] Energy consumed for all GPUs : 0.276794 kWh. Total GPU Power : 628.3190508121727 W\n[codecarbon INFO @ 21:47:26] Energy consumed for all CPUs : 0.018593 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:47:26] 0.367644 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:47:41] Energy consumed for RAM : 0.072944 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:47:41] Energy consumed for all GPUs : 0.279466 kWh. Total GPU Power : 641.8670593361426 W\n[codecarbon INFO @ 21:47:41] Energy consumed for all CPUs : 0.018770 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:47:41] 0.371181 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:47:56] Energy consumed for RAM : 0.073632 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:47:56] Energy consumed for all GPUs : 0.282137 kWh. Total GPU Power : 641.4703729656464 W\n[codecarbon INFO @ 21:47:56] Energy consumed for all CPUs : 0.018947 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:47:56] 0.374716 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:48:11] Energy consumed for RAM : 0.074321 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:48:11] Energy consumed for all GPUs : 0.284781 kWh. Total GPU Power : 635.1853195157872 W\n[codecarbon INFO @ 21:48:11] Energy consumed for all CPUs : 0.019125 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:48:11] 0.378226 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:48:26] Energy consumed for RAM : 0.075008 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:48:26] Energy consumed for all GPUs : 0.287442 kWh. Total GPU Power : 639.5462106369428 W\n[codecarbon INFO @ 21:48:26] Energy consumed for all CPUs : 0.019302 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:48:26] 0.381752 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:48:41] Energy consumed for RAM : 0.075697 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:48:41] Energy consumed for all GPUs : 0.290110 kWh. Total GPU Power : 640.8566755972051 W\n[codecarbon INFO @ 21:48:41] Energy consumed for all CPUs : 0.019479 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:48:41] 0.385285 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:48:56] Energy consumed for RAM : 0.076385 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:48:56] Energy consumed for all GPUs : 0.292782 kWh. Total GPU Power : 641.7052864666981 W\n[codecarbon INFO @ 21:48:56] Energy consumed for all CPUs : 0.019656 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:48:56] 0.388822 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:49:11] Energy consumed for RAM : 0.077073 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:49:11] Energy consumed for all GPUs : 0.295396 kWh. Total GPU Power : 628.327852773132 W\n[codecarbon INFO @ 21:49:11] Energy consumed for all CPUs : 0.019833 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:49:11] 0.392302 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:49:26] Energy consumed for RAM : 0.077761 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:49:26] Energy consumed for all GPUs : 0.298062 kWh. Total GPU Power : 640.3407136160951 W\n[codecarbon INFO @ 21:49:26] Energy consumed for all CPUs : 0.020010 kWh. 
Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:49:26] 0.395833 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:49:41] Energy consumed for RAM : 0.078448 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:49:41] Energy consumed for all GPUs : 0.300704 kWh. Total GPU Power : 635.323287223873 W\n[codecarbon INFO @ 21:49:41] Energy consumed for all CPUs : 0.020187 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:49:41] 0.399339 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:49:56] Energy consumed for RAM : 0.079137 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:49:56] Energy consumed for all GPUs : 0.303364 kWh. Total GPU Power : 638.8823021398234 W\n[codecarbon INFO @ 21:49:56] Energy consumed for all CPUs : 0.020364 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:49:56] 0.402864 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:50:11] Energy consumed for RAM : 0.079825 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:50:11] Energy consumed for all GPUs : 0.306026 kWh. Total GPU Power : 639.5449269632837 W\n[codecarbon INFO @ 21:50:11] Energy consumed for all CPUs : 0.020541 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:50:11] 0.406392 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:50:26] Energy consumed for RAM : 0.080513 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:50:26] Energy consumed for all GPUs : 0.308696 kWh. Total GPU Power : 641.2978321880182 W\n[codecarbon INFO @ 21:50:26] Energy consumed for all CPUs : 0.020718 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:50:26] 0.409927 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:50:41] Energy consumed for RAM : 0.081201 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:50:41] Energy consumed for all GPUs : 0.311344 kWh. Total GPU Power : 636.0606478210286 W\n[codecarbon INFO @ 21:50:41] Energy consumed for all CPUs : 0.020895 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:50:41] 0.413440 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:50:56] Energy consumed for RAM : 0.081889 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:50:56] Energy consumed for all GPUs : 0.313981 kWh. Total GPU Power : 633.4763683114048 W\n[codecarbon INFO @ 21:50:56] Energy consumed for all CPUs : 0.021072 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:50:56] 0.416942 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:51:11] Energy consumed for RAM : 0.082577 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:51:11] Energy consumed for all GPUs : 0.316621 kWh. Total GPU Power : 634.3093356260223 W\n[codecarbon INFO @ 21:51:11] Energy consumed for all CPUs : 0.021249 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:51:11] 0.420447 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:51:26] Energy consumed for RAM : 0.083265 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:51:26] Energy consumed for all GPUs : 0.319275 kWh. Total GPU Power : 637.9102256749104 W\n[codecarbon INFO @ 21:51:26] Energy consumed for all CPUs : 0.021426 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:51:26] 0.423967 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:51:41] Energy consumed for RAM : 0.083953 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:51:41] Energy consumed for all GPUs : 0.321942 kWh. 
Total GPU Power : 640.7082551761707 W\n[codecarbon INFO @ 21:51:41] Energy consumed for all CPUs : 0.021603 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:51:41] 0.427499 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:51:56] Energy consumed for RAM : 0.084642 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:51:56] Energy consumed for all GPUs : 0.324609 kWh. Total GPU Power : 640.5186958631616 W\n[codecarbon INFO @ 21:51:56] Energy consumed for all CPUs : 0.021781 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:51:56] 0.431032 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:52:11] Energy consumed for RAM : 0.085330 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:52:11] Energy consumed for all GPUs : 0.327266 kWh. Total GPU Power : 638.3529279530002 W\n[codecarbon INFO @ 21:52:11] Energy consumed for all CPUs : 0.021958 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:52:11] 0.434553 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:52:26] Energy consumed for RAM : 0.086018 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:52:26] Energy consumed for all GPUs : 0.329910 kWh. Total GPU Power : 635.3762025182446 W\n[codecarbon INFO @ 21:52:26] Energy consumed for all CPUs : 0.022135 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:52:26] 0.438062 kWh of electricity used since the beginning.\n[codecarbon INFO @ 21:52:41] Energy consumed for RAM : 0.086706 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 21:52:41] Energy consumed for all GPUs : 0.332545 kWh. Total GPU Power : 632.9484218502305 W\n[codecarbon INFO @ 21:52:41] Energy consumed for all CPUs : 0.022312 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 21:52:41] 0.441563 kWh of electricity used since the beginning.\n"
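The [codecarbon INFO] lines above come from codecarbon's energy tracker, which samples RAM, GPU and CPU power at roughly 15-second intervals and accumulates the kWh totals shown. A minimal sketch of how such logging is typically produced, assuming the codecarbon package; run_training() is a hypothetical stand-in for the Trainer.train() call being measured:

from codecarbon import EmissionsTracker

tracker = EmissionsTracker(measure_power_secs=15)  # sampling interval, matching the logs above
tracker.start()
try:
    run_training()                    # hypothetical: the training loop being measured
finally:
    emissions = tracker.stop()        # returns estimated kg CO2eq; energy totals are logged as above
print(f"Estimated emissions: {emissions:.4f} kg CO2eq")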
|
574 |
+
}
|
575 |
+
],
|
576 |
+
"execution_count": 16,
|
577 |
+
"metadata": {
|
578 |
+
"gather": {
|
579 |
+
"logged": 1706476228988
|
580 |
+
}
|
581 |
+
}
|
582 |
+
},
|
583 |
+
{
|
584 |
+
"cell_type": "code",
|
585 |
+
"source": [
|
586 |
+
"wandb.finish()"
|
587 |
+
],
|
588 |
+
"outputs": [],
|
589 |
+
"execution_count": null,
|
590 |
+
"metadata": {
|
591 |
+
"gather": {
|
592 |
+
"logged": 1706476229030
|
593 |
+
}
|
594 |
+
}
|
595 |
+
},
|
596 |
+
{
|
597 |
+
"cell_type": "code",
|
598 |
+
"source": [
|
599 |
+
"variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
|
600 |
+
"tokenizer._tokenizer.save(\"tokenizer.json\")\n",
|
601 |
+
"tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
|
602 |
+
"sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
|
603 |
+
"\n",
|
604 |
+
"model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
|
605 |
+
" variant=variant,\n",
|
606 |
+
" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
|
607 |
+
],
|
608 |
+
"outputs": [],
|
609 |
+
"execution_count": null,
|
610 |
+
"metadata": {
|
611 |
+
"gather": {
|
612 |
+
"logged": 1706476229038
|
613 |
+
}
|
614 |
+
}
|
615 |
+
}
|
616 |
+
],
|
617 |
+
"metadata": {
|
618 |
+
"datalore": {
|
619 |
+
"base_environment": "default",
|
620 |
+
"computation_mode": "JUPYTER",
|
621 |
+
"package_manager": "pip",
|
622 |
+
"packages": [
|
623 |
+
{
|
624 |
+
"name": "datasets",
|
625 |
+
"source": "PIP",
|
626 |
+
"version": "2.16.1"
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"name": "torch",
|
630 |
+
"source": "PIP",
|
631 |
+
"version": "2.1.2"
|
632 |
+
},
|
633 |
+
{
|
634 |
+
"name": "accelerate",
|
635 |
+
"source": "PIP",
|
636 |
+
"version": "0.26.1"
|
637 |
+
}
|
638 |
+
],
|
639 |
+
"report_row_ids": [
|
640 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
641 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
642 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
643 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
644 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
645 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
646 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
647 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
648 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
649 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
650 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
651 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
652 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
653 |
+
],
|
654 |
+
"version": 3
|
655 |
+
},
|
656 |
+
"kernelspec": {
|
657 |
+
"name": "python38-azureml-pt-tf",
|
658 |
+
"language": "python",
|
659 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow"
|
660 |
+
},
|
661 |
+
"language_info": {
|
662 |
+
"name": "python",
|
663 |
+
"version": "3.8.5",
|
664 |
+
"mimetype": "text/x-python",
|
665 |
+
"codemirror_mode": {
|
666 |
+
"name": "ipython",
|
667 |
+
"version": 3
|
668 |
+
},
|
669 |
+
"pygments_lexer": "ipython3",
|
670 |
+
"nbconvert_exporter": "python",
|
671 |
+
"file_extension": ".py"
|
672 |
+
},
|
673 |
+
"microsoft": {
|
674 |
+
"host": {
|
675 |
+
"AzureML": {
|
676 |
+
"notebookHasBeenCompleted": true
|
677 |
+
}
|
678 |
+
},
|
679 |
+
"ms_spell_check": {
|
680 |
+
"ms_spell_check_language": "en"
|
681 |
+
}
|
682 |
+
},
|
683 |
+
"nteract": {
|
684 |
+
"version": "nteract-front-end@1.0.0"
|
685 |
+
},
|
686 |
+
"kernel_info": {
|
687 |
+
"name": "python38-azureml-pt-tf"
|
688 |
+
}
|
689 |
+
},
|
690 |
+
"nbformat": 4,
|
691 |
+
"nbformat_minor": 4
|
692 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-3-12-1Z.ipynb
ADDED
@@ -0,0 +1,1053 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {
|
11 |
+
"collapsed": false
|
12 |
+
}
|
13 |
+
},
|
14 |
+
{
|
15 |
+
"cell_type": "code",
|
16 |
+
"source": [
|
17 |
+
"%pip install accelerate -U"
|
18 |
+
],
|
19 |
+
"outputs": [
|
20 |
+
{
|
21 |
+
"output_type": "stream",
|
22 |
+
"name": "stdout",
|
23 |
+
"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nNote: you may need to restart the kernel to use updated packages.\n"
|
24 |
+
}
|
25 |
+
],
|
26 |
+
"execution_count": 1,
|
27 |
+
"metadata": {
|
28 |
+
"jupyter": {
|
29 |
+
"source_hidden": false,
|
30 |
+
"outputs_hidden": false
|
31 |
+
},
|
32 |
+
"nteract": {
|
33 |
+
"transient": {
|
34 |
+
"deleting": false
|
35 |
+
}
|
36 |
+
}
|
37 |
+
}
|
38 |
+
},
|
39 |
+
{
|
40 |
+
"cell_type": "code",
|
41 |
+
"source": [
|
42 |
+
"%pip install transformers datasets shap watermark wandb"
|
43 |
+
],
|
44 |
+
"outputs": [
|
45 |
+
{
|
46 |
+
"output_type": "stream",
|
47 |
+
"name": "stderr",
|
48 |
+
"text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
|
49 |
+
},
|
50 |
+
{
|
51 |
+
"output_type": "stream",
|
52 |
+
"name": "stdout",
|
53 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nCollecting wandb\n Using cached wandb-0.16.2-py3-none-any.whl (2.2 MB)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) 
(1.2.2)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nCollecting sentry-sdk>=1.0.0\n Using cached sentry_sdk-1.39.2-py2.py3-none-any.whl (254 kB)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nCollecting docker-pycreds>=0.4.0\n Using cached docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nCollecting setproctitle\n Using cached setproctitle-1.3.3-cp38-cp38-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (31 kB)\nCollecting appdirs>=1.4.3\n Using cached appdirs-1.4.4-py2.py3-none-any.whl (9.6 kB)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already 
satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) 
(2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nInstalling collected packages: appdirs, setproctitle, sentry-sdk, docker-pycreds, wandb\nSuccessfully installed appdirs-1.4.4 docker-pycreds-0.4.0 sentry-sdk-1.39.2 setproctitle-1.3.3 wandb-0.16.2\nNote: you may need to restart the kernel to use updated packages.\n"
|
54 |
+
}
|
55 |
+
],
|
56 |
+
"execution_count": 17,
|
57 |
+
"metadata": {
|
58 |
+
"jupyter": {
|
59 |
+
"source_hidden": false,
|
60 |
+
"outputs_hidden": false
|
61 |
+
},
|
62 |
+
"nteract": {
|
63 |
+
"transient": {
|
64 |
+
"deleting": false
|
65 |
+
}
|
66 |
+
}
|
67 |
+
}
|
68 |
+
},
|
69 |
+
{
|
70 |
+
"cell_type": "code",
|
71 |
+
"source": [
|
72 |
+
"import pandas as pd\n",
|
73 |
+
"import numpy as np\n",
|
74 |
+
"import torch\n",
|
75 |
+
"import os\n",
|
76 |
+
"from typing import List\n",
|
77 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
78 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
|
79 |
+
"from datasets import load_dataset\n",
|
80 |
+
"import shap\n",
|
81 |
+
"\n",
|
82 |
+
"%load_ext watermark"
|
83 |
+
],
|
84 |
+
"outputs": [
|
85 |
+
{
|
86 |
+
"output_type": "stream",
|
87 |
+
"name": "stderr",
|
88 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-28 02:27:28.730200: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-28 02:27:29.708865: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-28 02:27:29.708983: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-28 02:27:29.708996: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
89 |
+
}
|
90 |
+
],
|
91 |
+
"execution_count": 3,
|
92 |
+
"metadata": {
|
93 |
+
"datalore": {
|
94 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
95 |
+
"type": "CODE",
|
96 |
+
"hide_input_from_viewers": false,
|
97 |
+
"hide_output_from_viewers": false,
|
98 |
+
"report_properties": {
|
99 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
100 |
+
}
|
101 |
+
},
|
102 |
+
"gather": {
|
103 |
+
"logged": 1706408851775
|
104 |
+
}
|
105 |
+
}
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"source": [
|
110 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
111 |
+
"\n",
|
112 |
+
"SEED: int = 42\n",
|
113 |
+
"\n",
|
114 |
+
"BATCH_SIZE: int = 8\n",
|
115 |
+
"EPOCHS: int = 1\n",
|
116 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
117 |
+
"\n",
|
118 |
+
"CLASS_NAMES: List[str] = [\"DIED\",\n",
|
119 |
+
" \"ER_VISIT\",\n",
|
120 |
+
" \"HOSPITAL\",\n",
|
121 |
+
" \"OFC_VISIT\",\n",
|
122 |
+
" \"X_STAY\",\n",
|
123 |
+
" \"DISABLE\",\n",
|
124 |
+
" \"D_PRESENTED\"]\n",
|
125 |
+
"\n",
|
126 |
+
"# WandB configuration\n",
|
127 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
|
128 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints"
|
129 |
+
],
|
130 |
+
"outputs": [],
|
131 |
+
"execution_count": 4,
|
132 |
+
"metadata": {
|
133 |
+
"collapsed": false,
|
134 |
+
"gather": {
|
135 |
+
"logged": 1706408852045
|
136 |
+
}
|
137 |
+
}
|
138 |
+
},
|
139 |
+
{
|
140 |
+
"cell_type": "code",
|
141 |
+
"source": [
|
142 |
+
"%watermark --iversion"
|
143 |
+
],
|
144 |
+
"outputs": [
|
145 |
+
{
|
146 |
+
"output_type": "stream",
|
147 |
+
"name": "stdout",
|
148 |
+
"text": "re : 2.2.1\nnumpy : 1.23.5\nlogging: 0.5.1.2\npandas : 2.0.2\ntorch : 1.12.0\nshap : 0.44.1\n\n"
|
149 |
+
}
|
150 |
+
],
|
151 |
+
"execution_count": 5,
|
152 |
+
"metadata": {
|
153 |
+
"collapsed": false
|
154 |
+
}
|
155 |
+
},
|
156 |
+
{
|
157 |
+
"cell_type": "code",
|
158 |
+
"source": [
|
159 |
+
"!nvidia-smi"
|
160 |
+
],
|
161 |
+
"outputs": [
|
162 |
+
{
|
163 |
+
"output_type": "stream",
|
164 |
+
"name": "stdout",
|
165 |
+
"text": "Sun Jan 28 02:27:31 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 28C P0 37W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 27C P0 36W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
166 |
+
}
|
167 |
+
],
|
168 |
+
"execution_count": 6,
|
169 |
+
"metadata": {
|
170 |
+
"datalore": {
|
171 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
172 |
+
"type": "CODE",
|
173 |
+
"hide_input_from_viewers": true,
|
174 |
+
"hide_output_from_viewers": true
|
175 |
+
}
|
176 |
+
}
|
177 |
+
},
|
178 |
+
{
|
179 |
+
"attachments": {},
|
180 |
+
"cell_type": "markdown",
|
181 |
+
"source": [
|
182 |
+
"## Loading the data set"
|
183 |
+
],
|
184 |
+
"metadata": {
|
185 |
+
"datalore": {
|
186 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
187 |
+
"type": "MD",
|
188 |
+
"hide_input_from_viewers": false,
|
189 |
+
"hide_output_from_viewers": false,
|
190 |
+
"report_properties": {
|
191 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
192 |
+
}
|
193 |
+
}
|
194 |
+
}
|
195 |
+
},
|
196 |
+
{
|
197 |
+
"cell_type": "code",
|
198 |
+
"source": [
|
199 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
200 |
+
],
|
201 |
+
"outputs": [],
|
202 |
+
"execution_count": 7,
|
203 |
+
"metadata": {
|
204 |
+
"collapsed": false,
|
205 |
+
"gather": {
|
206 |
+
"logged": 1706408853264
|
207 |
+
}
|
208 |
+
}
|
209 |
+
},
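load_dataset pulls the prepared VAERS outcomes dataset from the Hub as a DatasetDict. A hedged illustration of its expected shape; the split and column names are inferred from the cells below, not verified against the Hub:

print(dataset)
# DatasetDict({
#     train: Dataset({features: ['text', 'labels', ...], num_rows: ...}),
#     test:  Dataset({...}),
#     val:   Dataset({...})
# })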
|
210 |
+
{
|
211 |
+
"cell_type": "markdown",
|
212 |
+
"source": [
|
213 |
+
"### Tokenisation and encoding"
|
214 |
+
],
|
215 |
+
"metadata": {
|
216 |
+
"collapsed": false
|
217 |
+
}
|
218 |
+
},
|
219 |
+
{
|
220 |
+
"cell_type": "code",
|
221 |
+
"source": [
|
222 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
|
223 |
+
],
|
224 |
+
"outputs": [],
|
225 |
+
"execution_count": 8,
|
226 |
+
"metadata": {
|
227 |
+
"datalore": {
|
228 |
+
"node_id": "I7n646PIscsUZRoHu6m7zm",
|
229 |
+
"type": "CODE",
|
230 |
+
"hide_input_from_viewers": true,
|
231 |
+
"hide_output_from_viewers": true
|
232 |
+
},
|
233 |
+
"gather": {
|
234 |
+
"logged": 1706408853475
|
235 |
+
}
|
236 |
+
}
|
237 |
+
},
|
238 |
+
{
|
239 |
+
"cell_type": "code",
|
240 |
+
"source": [
|
241 |
+
"def tokenize_and_encode(examples):\n",
|
242 |
+
" return tokenizer(examples[\"text\"], truncation=True)"
|
243 |
+
],
|
244 |
+
"outputs": [],
|
245 |
+
"execution_count": 9,
|
246 |
+
"metadata": {
|
247 |
+
"datalore": {
|
248 |
+
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
|
249 |
+
"type": "CODE",
|
250 |
+
"hide_input_from_viewers": true,
|
251 |
+
"hide_output_from_viewers": true
|
252 |
+
},
|
253 |
+
"gather": {
|
254 |
+
"logged": 1706408853684
|
255 |
+
}
|
256 |
+
}
|
257 |
+
},
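tokenize_and_encode truncates each report to the encoder's maximum length; mapped over the dataset with batched=True below, it adds input_ids and attention_mask columns while remove_columns drops everything except labels. A quick hedged illustration (the report text is invented):

enc = tokenizer("Patient reported fever and chills after vaccination.", truncation=True)
print(list(enc.keys()))       # ['input_ids', 'attention_mask'] for DistilBERT
print(len(enc["input_ids"]))  # token count, capped at the 512-token model limit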
|
258 |
+
{
|
259 |
+
"cell_type": "code",
|
260 |
+
"source": [
|
261 |
+
"cols = dataset[\"train\"].column_names\n",
|
262 |
+
"cols.remove(\"labels\")\n",
|
263 |
+
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
|
264 |
+
],
|
265 |
+
"outputs": [
|
266 |
+
{
|
267 |
+
"output_type": "stream",
|
268 |
+
"name": "stderr",
|
269 |
+
"text": "Map: 100%|██████████| 15786/15786 [00:01<00:00, 10990.82 examples/s]\n"
|
270 |
+
}
|
271 |
+
],
|
272 |
+
"execution_count": 10,
|
273 |
+
"metadata": {
|
274 |
+
"datalore": {
|
275 |
+
"node_id": "slHeNysZOX9uWS9PB7jFDb",
|
276 |
+
"type": "CODE",
|
277 |
+
"hide_input_from_viewers": true,
|
278 |
+
"hide_output_from_viewers": true
|
279 |
+
},
|
280 |
+
"gather": {
|
281 |
+
"logged": 1706408854738
|
282 |
+
}
|
283 |
+
}
|
284 |
+
},
|
285 |
+
{
|
286 |
+
"cell_type": "markdown",
|
287 |
+
"source": [
|
288 |
+
"### Training"
|
289 |
+
],
|
290 |
+
"metadata": {
|
291 |
+
"collapsed": false
|
292 |
+
}
|
293 |
+
},
|
294 |
+
{
|
295 |
+
"cell_type": "code",
|
296 |
+
"source": [
|
297 |
+
"class MultiLabelTrainer(Trainer):\n",
|
298 |
+
" def compute_loss(self, model, inputs, return_outputs=False):\n",
|
299 |
+
" labels = inputs.pop(\"labels\")\n",
|
300 |
+
" outputs = model(**inputs)\n",
|
301 |
+
" logits = outputs.logits\n",
|
302 |
+
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
|
303 |
+
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
|
304 |
+
" labels.float().view(-1, self.model.config.num_labels))\n",
|
305 |
+
" return (loss, outputs) if return_outputs else loss"
|
306 |
+
],
|
307 |
+
"outputs": [],
|
308 |
+
"execution_count": 11,
|
309 |
+
"metadata": {
|
310 |
+
"datalore": {
|
311 |
+
"node_id": "itXWkbDw9sqbkMuDP84QoT",
|
312 |
+
"type": "CODE",
|
313 |
+
"hide_input_from_viewers": true,
|
314 |
+
"hide_output_from_viewers": true
|
315 |
+
},
|
316 |
+
"gather": {
|
317 |
+
"logged": 1706408854925
|
318 |
+
}
|
319 |
+
}
|
320 |
+
},
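The custom trainer swaps the default loss for BCEWithLogitsLoss, which scores each of the seven outcome labels as an independent binary decision; that is what makes this a multi-label rather than single-label setup. A small demonstration with made-up values:

import torch

logits = torch.randn(2, 7)                             # raw model outputs for 2 reports
labels = torch.tensor([[1., 0., 0., 1., 0., 0., 0.],   # e.g. DIED + OFC_VISIT
                       [0., 0., 1., 0., 0., 0., 0.]])  # e.g. HOSPITAL only
loss = torch.nn.BCEWithLogitsLoss()(logits, labels)    # scalar: mean over all 14 label slots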
|
321 |
+
{
|
322 |
+
"cell_type": "code",
|
323 |
+
"source": [
|
324 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
|
325 |
+
],
|
326 |
+
"outputs": [
|
327 |
+
{
|
328 |
+
"output_type": "stream",
|
329 |
+
"name": "stderr",
|
330 |
+
"text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
331 |
+
}
|
332 |
+
],
|
333 |
+
"execution_count": 12,
|
334 |
+
"metadata": {
|
335 |
+
"datalore": {
|
336 |
+
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
|
337 |
+
"type": "CODE",
|
338 |
+
"hide_input_from_viewers": true,
|
339 |
+
"hide_output_from_viewers": true
|
340 |
+
},
|
341 |
+
"gather": {
|
342 |
+
"logged": 1706408857008
|
343 |
+
}
|
344 |
+
}
|
345 |
+
},
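The warning is expected: the pre_classifier/classifier head is a fresh 7-way layer on top of the pretrained encoder and only acquires useful weights through fine-tuning. As an aside, recent transformers versions can apply the multi-label loss internally, which would make the custom trainer unnecessary; a hedged alternative:

model = AutoModelForSequenceClassification.from_pretrained(
    model_ckpt,
    num_labels=len(CLASS_NAMES),
    problem_type="multi_label_classification",  # Trainer then uses BCEWithLogitsLoss itself
).to(device)                                    # note: labels must be floats for this path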
|
346 |
+
{
|
347 |
+
"cell_type": "code",
|
348 |
+
"source": [
|
349 |
+
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
|
350 |
+
" y_pred = torch.from_numpy(y_pred)\n",
|
351 |
+
" y_true = torch.from_numpy(y_true)\n",
|
352 |
+
"\n",
|
353 |
+
" if sigmoid:\n",
|
354 |
+
" y_pred = y_pred.sigmoid()\n",
|
355 |
+
"\n",
|
356 |
+
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
|
357 |
+
],
|
358 |
+
"outputs": [],
|
359 |
+
"execution_count": 13,
|
360 |
+
"metadata": {
|
361 |
+
"datalore": {
|
362 |
+
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
|
363 |
+
"type": "CODE",
|
364 |
+
"hide_input_from_viewers": true,
|
365 |
+
"hide_output_from_viewers": true
|
366 |
+
},
|
367 |
+
"gather": {
|
368 |
+
"logged": 1706408857297
|
369 |
+
}
|
370 |
+
}
|
371 |
+
},
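accuracy_threshold sigmoid-squashes the logits, thresholds at 0.5 and averages element-wise agreement over every label slot, so correct negatives count as much as correct positives. A worked example with made-up numbers:

import numpy as np

y_pred = np.array([[2.0, -3.0, 0.1, -1.0, -2.0, -4.0, -1.5]])  # logits for one report
y_true = np.array([[1.0,  0.0, 0.0,  0.0,  0.0,  0.0,  0.0]])
# sigmoid(2.0) ~ 0.88 > 0.5 -> correct positive
# sigmoid(0.1) ~ 0.52 > 0.5 -> false positive; the other five slots are correct negatives
accuracy_threshold(y_pred, y_true)   # 6 of 7 slots agree -> ~0.857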
|
372 |
+
{
|
373 |
+
"cell_type": "code",
|
374 |
+
"source": [
|
375 |
+
"def compute_metrics(eval_pred):\n",
|
376 |
+
" predictions, labels = eval_pred\n",
|
377 |
+
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
|
378 |
+
],
|
379 |
+
"outputs": [],
|
380 |
+
"execution_count": 14,
|
381 |
+
"metadata": {
|
382 |
+
"datalore": {
|
383 |
+
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
|
384 |
+
"type": "CODE",
|
385 |
+
"hide_input_from_viewers": true,
|
386 |
+
"hide_output_from_viewers": true
|
387 |
+
},
|
388 |
+
"gather": {
|
389 |
+
"logged": 1706408857499
|
390 |
+
}
|
391 |
+
}
|
392 |
+
},
|
393 |
+
{
|
394 |
+
"cell_type": "code",
|
395 |
+
"source": [
|
396 |
+
"args = TrainingArguments(\n",
|
397 |
+
" output_dir=\"vaers\",\n",
|
398 |
+
" evaluation_strategy=\"epoch\",\n",
|
399 |
+
" learning_rate=2e-5,\n",
|
400 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
401 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
402 |
+
" num_train_epochs=EPOCHS,\n",
|
403 |
+
" weight_decay=.01,\n",
|
404 |
+
" report_to=[\"wandb\"]\n",
|
405 |
+
")"
|
406 |
+
],
|
407 |
+
"outputs": [],
|
408 |
+
"execution_count": 15,
|
409 |
+
"metadata": {
|
410 |
+
"datalore": {
|
411 |
+
"node_id": "1iPZOTKPwSkTgX5dORqT89",
|
412 |
+
"type": "CODE",
|
413 |
+
"hide_input_from_viewers": true,
|
414 |
+
"hide_output_from_viewers": true
|
415 |
+
},
|
416 |
+
"gather": {
|
417 |
+
"logged": 1706408857680
|
418 |
+
}
|
419 |
+
}
|
420 |
+
},
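With both V100s visible, the Trainer wraps the model in DataParallel by default, so per_device_train_batch_size=8 yields an effective batch of 16 per step; a hedged way to confirm:

effective_batch = BATCH_SIZE * max(1, torch.cuda.device_count())  # 8 * 2 = 16 on this VM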
|
421 |
+
{
|
422 |
+
"cell_type": "code",
|
423 |
+
"source": [
|
424 |
+
"multi_label_trainer = MultiLabelTrainer(\n",
|
425 |
+
" model, \n",
|
426 |
+
" args, \n",
|
427 |
+
" train_dataset=ds_enc[\"train\"], \n",
|
428 |
+
" eval_dataset=ds_enc[\"test\"], \n",
|
429 |
+
" compute_metrics=compute_metrics, \n",
|
430 |
+
" tokenizer=tokenizer\n",
|
431 |
+
")"
|
432 |
+
],
|
433 |
+
"outputs": [
|
434 |
+
{
|
435 |
+
"output_type": "stream",
|
436 |
+
"name": "stderr",
|
437 |
+
"text": "huggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
|
438 |
+
}
|
439 |
+
],
|
440 |
+
"execution_count": 18,
|
441 |
+
"metadata": {
|
442 |
+
"datalore": {
|
443 |
+
"node_id": "bnRkNvRYltLun6gCEgL7v0",
|
444 |
+
"type": "CODE",
|
445 |
+
"hide_input_from_viewers": true,
|
446 |
+
"hide_output_from_viewers": true
|
447 |
+
},
|
448 |
+
"gather": {
|
449 |
+
"logged": 1706408895305
|
450 |
+
}
|
451 |
+
}
|
452 |
+
},
|
453 |
+
{
|
454 |
+
"cell_type": "code",
|
455 |
+
"source": [
|
456 |
+
"multi_label_trainer.evaluate()"
|
457 |
+
],
|
458 |
+
"outputs": [
|
459 |
+
{
|
460 |
+
"output_type": "display_data",
|
461 |
+
"data": {
|
462 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
463 |
+
"text/html": "\n <div>\n \n <progress value='1974' max='987' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [987/987 21:41]\n </div>\n "
|
464 |
+
},
|
465 |
+
"metadata": {}
|
466 |
+
},
|
467 |
+
{
|
468 |
+
"output_type": "stream",
|
469 |
+
"name": "stderr",
|
470 |
+
"text": "Failed to detect the name of this notebook, you can set it manually with the WANDB_NOTEBOOK_NAME environment variable to enable code saving.\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\nhuggingface/tokenizers: The current process just got forked, after parallelism has already been used. Disabling parallelism to avoid deadlocks...\nTo disable this warning, you can either:\n\t- Avoid using `tokenizers` before the fork if possible\n\t- Explicitly set the environment variable TOKENIZERS_PARALLELISM=(true | false)\n"
|
471 |
+
},
|
472 |
+
{
|
473 |
+
"output_type": "display_data",
|
474 |
+
"data": {
|
475 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
476 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
477 |
+
},
|
478 |
+
"metadata": {}
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"output_type": "display_data",
|
482 |
+
"data": {
|
483 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
484 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_022947-hh1sxw9i</code>"
|
485 |
+
},
|
486 |
+
"metadata": {}
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"output_type": "display_data",
|
490 |
+
"data": {
|
491 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
492 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i' target=\"_blank\">icy-firebrand-1</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
493 |
+
},
|
494 |
+
"metadata": {}
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"output_type": "display_data",
|
498 |
+
"data": {
|
499 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
500 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
501 |
+
},
|
502 |
+
"metadata": {}
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"output_type": "display_data",
|
506 |
+
"data": {
|
507 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
508 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/hh1sxw9i</a>"
|
509 |
+
},
|
510 |
+
"metadata": {}
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"output_type": "execute_result",
|
514 |
+
"execution_count": 19,
|
515 |
+
"data": {
|
516 |
+
"text/plain": "{'eval_loss': 0.7153111100196838,\n 'eval_accuracy_thresh': 0.2938227355480194,\n 'eval_runtime': 82.3613,\n 'eval_samples_per_second': 191.668,\n 'eval_steps_per_second': 11.984}"
|
517 |
+
},
|
518 |
+
"metadata": {}
|
519 |
+
}
|
520 |
+
],
|
521 |
+
"execution_count": 19,
|
522 |
+
"metadata": {
|
523 |
+
"datalore": {
|
524 |
+
"node_id": "LO54PlDkWQdFrzV25FvduB",
|
525 |
+
"type": "CODE",
|
526 |
+
"hide_input_from_viewers": true,
|
527 |
+
"hide_output_from_viewers": true
|
528 |
+
},
|
529 |
+
"gather": {
|
530 |
+
"logged": 1706408991752
|
531 |
+
}
|
532 |
+
}
|
533 |
+
},
|
534 |
+
{
|
535 |
+
"cell_type": "code",
|
536 |
+
"source": [
|
537 |
+
"multi_label_trainer.train()"
|
538 |
+
],
|
539 |
+
"outputs": [
|
540 |
+
{
|
541 |
+
"output_type": "display_data",
|
542 |
+
"data": {
|
543 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
544 |
+
"text/html": "\n <div>\n \n <progress value='4605' max='4605' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [4605/4605 20:25, Epoch 1/1]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n <th>Accuracy Thresh</th>\n </tr>\n </thead>\n <tbody>\n <tr>\n <td>1</td>\n <td>0.086700</td>\n <td>0.093388</td>\n <td>0.962897</td>\n </tr>\n </tbody>\n</table><p>"
|
545 |
+
},
|
546 |
+
"metadata": {}
|
547 |
+
},
|
548 |
+
{
|
549 |
+
"output_type": "stream",
|
550 |
+
"name": "stderr",
|
551 |
+
"text": "Checkpoint destination directory vaers/checkpoint-500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-500)... Done. 15.9s\nCheckpoint destination directory vaers/checkpoint-1000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1000)... Done. 12.5s\nCheckpoint destination directory vaers/checkpoint-1500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-1500)... Done. 21.9s\nCheckpoint destination directory vaers/checkpoint-2000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2000)... Done. 13.8s\nCheckpoint destination directory vaers/checkpoint-2500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-2500)... Done. 15.7s\nCheckpoint destination directory vaers/checkpoint-3000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3000)... Done. 21.7s\nCheckpoint destination directory vaers/checkpoint-3500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-3500)... Done. 10.6s\nCheckpoint destination directory vaers/checkpoint-4000 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4000)... Done. 15.0s\nCheckpoint destination directory vaers/checkpoint-4500 already exists and is non-empty.Saving will proceed but saved results may be invalid.\n\u001b[34m\u001b[1mwandb\u001b[0m: Adding directory to artifact (./vaers/checkpoint-4500)... Done. 16.7s\n"
|
552 |
+
},
|
553 |
+
{
|
554 |
+
"output_type": "execute_result",
|
555 |
+
"execution_count": 21,
|
556 |
+
"data": {
|
557 |
+
"text/plain": "TrainOutput(global_step=4605, training_loss=0.09062977189220382, metrics={'train_runtime': 1223.2444, 'train_samples_per_second': 60.223, 'train_steps_per_second': 3.765, 'total_flos': 9346797199425174.0, 'train_loss': 0.09062977189220382, 'epoch': 1.0})"
|
558 |
+
},
|
559 |
+
"metadata": {}
|
560 |
+
}
|
561 |
+
],
|
562 |
+
"execution_count": 21,
|
563 |
+
"metadata": {
|
564 |
+
"datalore": {
|
565 |
+
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
|
566 |
+
"type": "CODE",
|
567 |
+
"hide_input_from_viewers": true,
|
568 |
+
"hide_output_from_viewers": true
|
569 |
+
},
|
570 |
+
"gather": {
|
571 |
+
"logged": 1706411445752
|
572 |
+
}
|
573 |
+
}
|
574 |
+
},
|
575 |
+
{
|
576 |
+
"cell_type": "markdown",
|
577 |
+
"source": [
|
578 |
+
"### Evaluation"
|
579 |
+
],
|
580 |
+
"metadata": {
|
581 |
+
"collapsed": false
|
582 |
+
}
|
583 |
+
},
|
584 |
+
{
|
585 |
+
"cell_type": "markdown",
|
586 |
+
"source": [
|
587 |
+
"We instantiate a classifier `pipeline` and push it to CUDA."
|
588 |
+
],
|
589 |
+
"metadata": {
|
590 |
+
"collapsed": false
|
591 |
+
}
|
592 |
+
},
|
593 |
+
{
|
594 |
+
"cell_type": "code",
|
595 |
+
"source": [
|
596 |
+
"classifier = pipeline(\"text-classification\", \n",
|
597 |
+
" model, \n",
|
598 |
+
" tokenizer=tokenizer, \n",
|
599 |
+
" device=\"cuda:0\")"
|
600 |
+
],
|
601 |
+
"outputs": [],
|
602 |
+
"execution_count": 24,
|
603 |
+
"metadata": {
|
604 |
+
"datalore": {
|
605 |
+
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
|
606 |
+
"type": "CODE",
|
607 |
+
"hide_input_from_viewers": true,
|
608 |
+
"hide_output_from_viewers": true
|
609 |
+
},
|
610 |
+
"gather": {
|
611 |
+
"logged": 1706411459928
|
612 |
+
}
|
613 |
+
}
|
614 |
+
},
|
615 |
+
{
|
616 |
+
"cell_type": "markdown",
|
617 |
+
"source": [
|
618 |
+
"We use the same tokenizer used for training to tokenize/encode the validation set."
|
619 |
+
],
|
620 |
+
"metadata": {
|
621 |
+
"collapsed": false
|
622 |
+
}
|
623 |
+
},
|
624 |
+
{
|
625 |
+
"cell_type": "code",
|
626 |
+
"source": [
|
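+
"# NB: pad_to_max_length is deprecated in recent transformers releases; padding='max_length' is the equivalent.\n",
|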
627 |
+
"test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
|
628 |
+
" max_length=255, \n",
|
629 |
+
" pad_to_max_length=True, \n",
|
630 |
+
" return_token_type_ids=True, \n",
|
631 |
+
" truncation=True)"
|
632 |
+
],
|
633 |
+
"outputs": [
|
634 |
+
{
|
635 |
+
"output_type": "error",
|
636 |
+
"ename": "KeyError",
|
637 |
+
"evalue": "'validate'",
|
638 |
+
"traceback": [
|
639 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
640 |
+
"\u001b[0;31mKeyError\u001b[0m Traceback (most recent call last)",
|
641 |
+
"Cell \u001b[0;32mIn[25], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m test_encodings \u001b[38;5;241m=\u001b[39m tokenizer\u001b[38;5;241m.\u001b[39mbatch_encode_plus(\u001b[43mdataset\u001b[49m\u001b[43m[\u001b[49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[38;5;124;43mvalidate\u001b[39;49m\u001b[38;5;124;43m\"\u001b[39;49m\u001b[43m]\u001b[49m[\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mtext\u001b[39m\u001b[38;5;124m\"\u001b[39m], \n\u001b[1;32m 2\u001b[0m max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;241m255\u001b[39m, \n\u001b[1;32m 3\u001b[0m pad_to_max_length\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \n\u001b[1;32m 4\u001b[0m return_token_type_ids\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m, \n\u001b[1;32m 5\u001b[0m truncation\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mTrue\u001b[39;00m)\n",
|
642 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/datasets/dataset_dict.py:74\u001b[0m, in \u001b[0;36mDatasetDict.__getitem__\u001b[0;34m(self, k)\u001b[0m\n\u001b[1;32m 72\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21m__getitem__\u001b[39m(\u001b[38;5;28mself\u001b[39m, k) \u001b[38;5;241m-\u001b[39m\u001b[38;5;241m>\u001b[39m Dataset:\n\u001b[1;32m 73\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(k, (\u001b[38;5;28mstr\u001b[39m, NamedSplit)) \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mlen\u001b[39m(\u001b[38;5;28mself\u001b[39m) \u001b[38;5;241m==\u001b[39m \u001b[38;5;241m0\u001b[39m:\n\u001b[0;32m---> 74\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28;43msuper\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[38;5;21;43m__getitem__\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mk\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 75\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[1;32m 76\u001b[0m available_suggested_splits \u001b[38;5;241m=\u001b[39m [\n\u001b[1;32m 77\u001b[0m split \u001b[38;5;28;01mfor\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m (Split\u001b[38;5;241m.\u001b[39mTRAIN, Split\u001b[38;5;241m.\u001b[39mTEST, Split\u001b[38;5;241m.\u001b[39mVALIDATION) \u001b[38;5;28;01mif\u001b[39;00m split \u001b[38;5;129;01min\u001b[39;00m \u001b[38;5;28mself\u001b[39m\n\u001b[1;32m 78\u001b[0m ]\n",
|
643 |
+
"\u001b[0;31mKeyError\u001b[0m: 'validate'"
|
644 |
+
]
|
645 |
+
}
|
646 |
+
],
|
647 |
+
"execution_count": 25,
|
648 |
+
"metadata": {
|
649 |
+
"datalore": {
|
650 |
+
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
|
651 |
+
"type": "CODE",
|
652 |
+
"hide_input_from_viewers": true,
|
653 |
+
"hide_output_from_viewers": true
|
654 |
+
},
|
655 |
+
"gather": {
|
656 |
+
"logged": 1706411465538
|
657 |
+
}
|
658 |
+
}
|
659 |
+
},
|
660 |
+
{
|
661 |
+
"cell_type": "markdown",
|
662 |
+
"source": [
|
663 |
+
"Once we've made the data loadable by putting it into a `DataLoader`, we "
|
664 |
+
],
|
665 |
+
"metadata": {
|
666 |
+
"collapsed": false
|
667 |
+
}
|
668 |
+
},
|
669 |
+
{
|
670 |
+
"cell_type": "code",
|
671 |
+
"source": [
|
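+
"# SequentialSampler keeps the original row order, so predictions stay aligned with the labels.\n",
|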
672 |
+
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
|
673 |
+
" torch.tensor(test_encodings['attention_mask']), \n",
|
674 |
+
" torch.tensor(ds_enc[\"validate\"][\"labels\"]), \n",
|
675 |
+
" torch.tensor(test_encodings['token_type_ids']))\n",
|
676 |
+
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
|
677 |
+
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
|
678 |
+
" batch_size=BATCH_SIZE)"
|
679 |
+
],
|
680 |
+
"outputs": [],
|
681 |
+
"execution_count": null,
|
682 |
+
"metadata": {
|
683 |
+
"datalore": {
|
684 |
+
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
|
685 |
+
"type": "CODE",
|
686 |
+
"hide_input_from_viewers": true,
|
687 |
+
"hide_output_from_viewers": true
|
688 |
+
},
|
689 |
+
"gather": {
|
690 |
+
"logged": 1706411446707
|
691 |
+
}
|
692 |
+
}
|
693 |
+
},
|
694 |
+
{
|
695 |
+
"cell_type": "code",
|
696 |
+
"source": [
|
697 |
+
"model.eval()\n",
|
698 |
+
"\n",
|
699 |
+
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
|
700 |
+
"\n",
|
701 |
+
"for i, batch in enumerate(test_dataloader):\n",
|
702 |
+
" batch = tuple(t.to(device) for t in batch)\n",
|
703 |
+
" # Unpack the inputs from our dataloader\n",
|
704 |
+
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
|
705 |
+
" \n",
|
706 |
+
" with torch.no_grad():\n",
|
707 |
+
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
|
708 |
+
" b_logit_pred = outs[0]\n",
|
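+
"        # sigmoid scores each class independently (multi-label style); for single-label multiclass, softmax + argmax is the usual choice.\n",
|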
709 |
+
" pred_label = torch.sigmoid(b_logit_pred)\n",
|
710 |
+
"\n",
|
711 |
+
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
|
712 |
+
" pred_label = pred_label.to('cpu').numpy()\n",
|
713 |
+
" b_labels = b_labels.to('cpu').numpy()\n",
|
714 |
+
"\n",
|
715 |
+
" tokenized_texts.append(b_input_ids)\n",
|
716 |
+
" logit_preds.append(b_logit_pred)\n",
|
717 |
+
" true_labels.append(b_labels)\n",
|
718 |
+
" pred_labels.append(pred_label)\n",
|
719 |
+
"\n",
|
720 |
+
"# Flatten outputs\n",
|
721 |
+
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
|
722 |
+
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
|
723 |
+
"true_labels = [item for sublist in true_labels for item in sublist]\n",
|
724 |
+
"\n",
|
725 |
+
"# Converting flattened binary values to boolean values\n",
|
726 |
+
"true_bools = [tl == 1 for tl in true_labels]\n",
|
727 |
+
"pred_bools = [pl > 0.50 for pl in pred_labels] "
|
728 |
+
],
|
729 |
+
"outputs": [],
|
730 |
+
"execution_count": null,
|
731 |
+
"metadata": {
|
732 |
+
"datalore": {
|
733 |
+
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
|
734 |
+
"type": "CODE",
|
735 |
+
"hide_input_from_viewers": true,
|
736 |
+
"hide_output_from_viewers": true
|
737 |
+
},
|
738 |
+
"gather": {
|
739 |
+
"logged": 1706411446723
|
740 |
+
}
|
741 |
+
}
|
742 |
+
},
|
743 |
+
{
|
744 |
+
"cell_type": "markdown",
|
745 |
+
"source": [
|
746 |
+
"We create a classification report:"
|
747 |
+
],
|
748 |
+
"metadata": {
|
749 |
+
"collapsed": false
|
750 |
+
}
|
751 |
+
},
|
752 |
+
{
|
753 |
+
"cell_type": "code",
|
754 |
+
"source": [
|
755 |
+
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
|
756 |
+
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
|
757 |
+
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
|
758 |
+
"print(clf_report)"
|
759 |
+
],
|
760 |
+
"outputs": [],
|
761 |
+
"execution_count": null,
|
762 |
+
"metadata": {
|
763 |
+
"datalore": {
|
764 |
+
"node_id": "eBprrgF086mznPbPVBpOLS",
|
765 |
+
"type": "CODE",
|
766 |
+
"hide_input_from_viewers": true,
|
767 |
+
"hide_output_from_viewers": true
|
768 |
+
},
|
769 |
+
"gather": {
|
770 |
+
"logged": 1706411446746
|
771 |
+
}
|
772 |
+
}
|
773 |
+
},
|
774 |
+
{
|
775 |
+
"cell_type": "markdown",
|
776 |
+
"source": [
|
777 |
+
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
|
778 |
+
],
|
779 |
+
"metadata": {
|
780 |
+
"collapsed": false
|
781 |
+
}
|
782 |
+
},
|
783 |
+
{
|
784 |
+
"cell_type": "code",
|
785 |
+
"source": [
|
786 |
+
"# Creating a map of class names from class numbers\n",
|
787 |
+
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
|
788 |
+
],
|
789 |
+
"outputs": [],
|
790 |
+
"execution_count": null,
|
791 |
+
"metadata": {
|
792 |
+
"datalore": {
|
793 |
+
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
|
794 |
+
"type": "CODE",
|
795 |
+
"hide_input_from_viewers": true,
|
796 |
+
"hide_output_from_viewers": true
|
797 |
+
},
|
798 |
+
"gather": {
|
799 |
+
"logged": 1706411446758
|
800 |
+
}
|
801 |
+
}
|
802 |
+
},
|
803 |
+
{
|
804 |
+
"cell_type": "code",
|
805 |
+
"source": [
|
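+
"# np.where(vals)[0] yields the indices of True entries, i.e. the class ids.\n",
|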
806 |
+
"true_label_idxs, pred_label_idxs = [], []\n",
|
807 |
+
"\n",
|
808 |
+
"for vals in true_bools:\n",
|
809 |
+
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
|
810 |
+
"for vals in pred_bools:\n",
|
811 |
+
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
|
812 |
+
],
|
813 |
+
"outputs": [],
|
814 |
+
"execution_count": null,
|
815 |
+
"metadata": {
|
816 |
+
"datalore": {
|
817 |
+
"node_id": "jH0S35dDteUch01sa6me6e",
|
818 |
+
"type": "CODE",
|
819 |
+
"hide_input_from_viewers": true,
|
820 |
+
"hide_output_from_viewers": true
|
821 |
+
},
|
822 |
+
"gather": {
|
823 |
+
"logged": 1706411446771
|
824 |
+
}
|
825 |
+
}
|
826 |
+
},
|
827 |
+
{
|
828 |
+
"cell_type": "code",
|
829 |
+
"source": [
|
830 |
+
"true_label_texts, pred_label_texts = [], []\n",
|
831 |
+
"\n",
|
832 |
+
"for vals in true_label_idxs:\n",
|
833 |
+
" if vals:\n",
|
834 |
+
" true_label_texts.append([idx2label[val] for val in vals])\n",
|
835 |
+
" else:\n",
|
836 |
+
" true_label_texts.append(vals)\n",
|
837 |
+
"\n",
|
838 |
+
"for vals in pred_label_idxs:\n",
|
839 |
+
" if vals:\n",
|
840 |
+
" pred_label_texts.append([idx2label[val] for val in vals])\n",
|
841 |
+
" else:\n",
|
842 |
+
" pred_label_texts.append(vals)"
|
843 |
+
],
|
844 |
+
"outputs": [],
|
845 |
+
"execution_count": null,
|
846 |
+
"metadata": {
|
847 |
+
"datalore": {
|
848 |
+
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
|
849 |
+
"type": "CODE",
|
850 |
+
"hide_input_from_viewers": true,
|
851 |
+
"hide_output_from_viewers": true
|
852 |
+
},
|
853 |
+
"gather": {
|
854 |
+
"logged": 1706411446785
|
855 |
+
}
|
856 |
+
}
|
857 |
+
},
|
858 |
+
{
|
859 |
+
"cell_type": "code",
|
860 |
+
"source": [
|
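+
"# Decode the token ids back into readable text for the comparison table.\n",
|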
861 |
+
"symptom_texts = [tokenizer.decode(text,\n",
|
862 |
+
" skip_special_tokens=True,\n",
|
863 |
+
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
|
864 |
+
],
|
865 |
+
"outputs": [],
|
866 |
+
"execution_count": null,
|
867 |
+
"metadata": {
|
868 |
+
"datalore": {
|
869 |
+
"node_id": "SxUmVHfQISEeptg1SawOmB",
|
870 |
+
"type": "CODE",
|
871 |
+
"hide_input_from_viewers": true,
|
872 |
+
"hide_output_from_viewers": true
|
873 |
+
},
|
874 |
+
"gather": {
|
875 |
+
"logged": 1706411446805
|
876 |
+
}
|
877 |
+
}
|
878 |
+
},
|
879 |
+
{
|
880 |
+
"cell_type": "code",
|
881 |
+
"source": [
|
882 |
+
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
|
883 |
+
" 'true_labels': true_label_texts, \n",
|
884 |
+
" 'pred_labels':pred_label_texts})\n",
|
885 |
+
"comparisons_df.to_csv('comparisons.csv')\n",
|
886 |
+
"comparisons_df"
|
887 |
+
],
|
888 |
+
"outputs": [],
|
889 |
+
"execution_count": null,
|
890 |
+
"metadata": {
|
891 |
+
"datalore": {
|
892 |
+
"node_id": "BxFNigNGRLTOqraI55BPSH",
|
893 |
+
"type": "CODE",
|
894 |
+
"hide_input_from_viewers": true,
|
895 |
+
"hide_output_from_viewers": true
|
896 |
+
},
|
897 |
+
"gather": {
|
898 |
+
"logged": 1706411446818
|
899 |
+
}
|
900 |
+
}
|
901 |
+
},
|
902 |
+
{
|
903 |
+
"cell_type": "markdown",
|
904 |
+
"source": [
|
905 |
+
"### Shapley analysis"
|
906 |
+
],
|
907 |
+
"metadata": {
|
908 |
+
"collapsed": false
|
909 |
+
}
|
910 |
+
},
|
911 |
+
{
|
912 |
+
"cell_type": "code",
|
913 |
+
"source": [
|
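+
"# shap wraps the transformers pipeline with a text masker; output_names supplies the class labels shown in plots.\n",
|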
914 |
+
"explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
|
915 |
+
],
|
916 |
+
"outputs": [],
|
917 |
+
"execution_count": null,
|
918 |
+
"metadata": {
|
919 |
+
"datalore": {
|
920 |
+
"node_id": "OpdZcoenX2HwzLdai7K5UA",
|
921 |
+
"type": "CODE",
|
922 |
+
"hide_input_from_viewers": true,
|
923 |
+
"hide_output_from_viewers": true
|
924 |
+
},
|
925 |
+
"gather": {
|
926 |
+
"logged": 1706411446829
|
927 |
+
}
|
928 |
+
}
|
929 |
+
},
|
930 |
+
{
|
931 |
+
"cell_type": "code",
|
932 |
+
"source": [
|
933 |
+
"shap_values = explainer(dataset[\"validate\"][\"text\"][1:2])"
|
934 |
+
],
|
935 |
+
"outputs": [],
|
936 |
+
"execution_count": null,
|
937 |
+
"metadata": {
|
938 |
+
"datalore": {
|
939 |
+
"node_id": "FvbCMfIDlcf16YSvb8wNQv",
|
940 |
+
"type": "CODE",
|
941 |
+
"hide_input_from_viewers": true,
|
942 |
+
"hide_output_from_viewers": true
|
943 |
+
},
|
944 |
+
"gather": {
|
945 |
+
"logged": 1706411446839
|
946 |
+
}
|
947 |
+
}
|
948 |
+
},
|
949 |
+
{
|
950 |
+
"cell_type": "code",
|
951 |
+
"source": [
|
952 |
+
"shap.plots.text(shap_values)"
|
953 |
+
],
|
954 |
+
"outputs": [],
|
955 |
+
"execution_count": null,
|
956 |
+
"metadata": {
|
957 |
+
"datalore": {
|
958 |
+
"node_id": "TSxvakWLPCpjVMWi9ZdEbd",
|
959 |
+
"type": "CODE",
|
960 |
+
"hide_input_from_viewers": true,
|
961 |
+
"hide_output_from_viewers": true
|
962 |
+
},
|
963 |
+
"gather": {
|
964 |
+
"logged": 1706411446848
|
965 |
+
}
|
966 |
+
}
|
967 |
+
},
|
968 |
+
{
|
969 |
+
"cell_type": "code",
|
970 |
+
"source": [],
|
971 |
+
"outputs": [],
|
972 |
+
"execution_count": null,
|
973 |
+
"metadata": {
|
974 |
+
"jupyter": {
|
975 |
+
"source_hidden": false,
|
976 |
+
"outputs_hidden": false
|
977 |
+
},
|
978 |
+
"nteract": {
|
979 |
+
"transient": {
|
980 |
+
"deleting": false
|
981 |
+
}
|
982 |
+
}
|
983 |
+
}
|
984 |
+
}
|
985 |
+
],
|
986 |
+
"metadata": {
|
987 |
+
"kernelspec": {
|
988 |
+
"name": "python3",
|
989 |
+
"language": "python",
|
990 |
+
"display_name": "Python 3 (ipykernel)"
|
991 |
+
},
|
992 |
+
"datalore": {
|
993 |
+
"computation_mode": "JUPYTER",
|
994 |
+
"package_manager": "pip",
|
995 |
+
"base_environment": "default",
|
996 |
+
"packages": [
|
997 |
+
{
|
998 |
+
"name": "datasets",
|
999 |
+
"version": "2.16.1",
|
1000 |
+
"source": "PIP"
|
1001 |
+
},
|
1002 |
+
{
|
1003 |
+
"name": "torch",
|
1004 |
+
"version": "2.1.2",
|
1005 |
+
"source": "PIP"
|
1006 |
+
},
|
1007 |
+
{
|
1008 |
+
"name": "accelerate",
|
1009 |
+
"version": "0.26.1",
|
1010 |
+
"source": "PIP"
|
1011 |
+
}
|
1012 |
+
],
|
1013 |
+
"report_row_ids": [
|
1014 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
1015 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
1016 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
1017 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
1018 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
1019 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
1020 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
1021 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
1022 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
1023 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
1024 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
1025 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
1026 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
1027 |
+
],
|
1028 |
+
"version": 3
|
1029 |
+
},
|
1030 |
+
"microsoft": {
|
1031 |
+
"ms_spell_check": {
|
1032 |
+
"ms_spell_check_language": "en"
|
1033 |
+
}
|
1034 |
+
},
|
1035 |
+
"language_info": {
|
1036 |
+
"name": "python",
|
1037 |
+
"version": "3.8.5",
|
1038 |
+
"mimetype": "text/x-python",
|
1039 |
+
"codemirror_mode": {
|
1040 |
+
"name": "ipython",
|
1041 |
+
"version": 3
|
1042 |
+
},
|
1043 |
+
"pygments_lexer": "ipython3",
|
1044 |
+
"nbconvert_exporter": "python",
|
1045 |
+
"file_extension": ".py"
|
1046 |
+
},
|
1047 |
+
"nteract": {
|
1048 |
+
"version": "nteract-front-end@1.0.0"
|
1049 |
+
}
|
1050 |
+
},
|
1051 |
+
"nbformat": 4,
|
1052 |
+
"nbformat_minor": 4
|
1053 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-28-4-13-53Z.ipynb
ADDED
The diff for this file is too large to render.
See raw diff
|
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-14-26-30Z.ipynb
ADDED
@@ -0,0 +1,739 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {}
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"source": [
|
15 |
+
"%pip install accelerate -U"
|
16 |
+
],
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"output_type": "stream",
|
20 |
+
"name": "stdout",
|
21 |
+
"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nNote: you may need to restart the kernel to use updated packages.\n"
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"execution_count": 1,
|
25 |
+
"metadata": {
|
26 |
+
"gather": {
|
27 |
+
"logged": 1706475754655
|
28 |
+
},
|
29 |
+
"nteract": {
|
30 |
+
"transient": {
|
31 |
+
"deleting": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"tags": []
|
35 |
+
}
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"source": [
|
40 |
+
"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
|
41 |
+
],
|
42 |
+
"outputs": [
|
43 |
+
{
|
44 |
+
"output_type": "stream",
|
45 |
+
"name": "stdout",
|
46 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"execution_count": 2,
|
50 |
+
"metadata": {
|
51 |
+
"nteract": {
|
52 |
+
"transient": {
|
53 |
+
"deleting": false
|
54 |
+
}
|
55 |
+
}
|
56 |
+
}
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"import pandas as pd\n",
|
62 |
+
"import numpy as np\n",
|
63 |
+
"import torch\n",
|
64 |
+
"import os\n",
|
65 |
+
"from typing import List, Union\n",
|
66 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
|
67 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
68 |
+
"import shap\n",
|
69 |
+
"import wandb\n",
|
70 |
+
"import evaluate\n",
|
71 |
+
"from codecarbon import EmissionsTracker\n",
|
72 |
+
"import logging\n",
|
73 |
+
"\n",
|
74 |
+
"wandb.finish()\n",
|
75 |
+
"\n",
|
76 |
+
"logging.getLogger('codecarbon').propagate = False\n",
|
77 |
+
"\n",
|
78 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
79 |
+
"tracker = EmissionsTracker()\n",
|
80 |
+
"\n",
|
81 |
+
"%load_ext watermark"
|
82 |
+
],
|
83 |
+
"outputs": [
|
84 |
+
{
|
85 |
+
"output_type": "stream",
|
86 |
+
"name": "stderr",
|
87 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 04:43:58.191236: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 04:43:59.182154: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 04:43:59.182291: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 04:43:59.182304: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n[codecarbon INFO @ 04:44:02] [setup] RAM Tracking...\n[codecarbon INFO @ 04:44:02] [setup] GPU Tracking...\n[codecarbon INFO @ 04:44:02] Tracking Nvidia GPU via pynvml\n[codecarbon INFO @ 04:44:02] [setup] CPU Tracking...\n[codecarbon WARNING @ 04:44:02] No CPU tracking mode found. Falling back on CPU constant mode.\n[codecarbon WARNING @ 04:44:03] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n[codecarbon INFO @ 04:44:03] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 04:44:03] >>> Tracker's metadata:\n[codecarbon INFO @ 04:44:03] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n[codecarbon INFO @ 04:44:03] Python version: 3.8.5\n[codecarbon INFO @ 04:44:03] CodeCarbon version: 2.3.3\n[codecarbon INFO @ 04:44:03] Available RAM : 440.883 GB\n[codecarbon INFO @ 04:44:03] CPU count: 24\n[codecarbon INFO @ 04:44:03] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n[codecarbon INFO @ 04:44:03] GPU count: 4\n[codecarbon INFO @ 04:44:03] GPU model: 4 x Tesla V100-PCIE-16GB\n[codecarbon WARNING @ 04:44:03] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
|
88 |
+
}
|
89 |
+
],
|
90 |
+
"execution_count": 3,
|
91 |
+
"metadata": {
|
92 |
+
"datalore": {
|
93 |
+
"hide_input_from_viewers": false,
|
94 |
+
"hide_output_from_viewers": false,
|
95 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
96 |
+
"report_properties": {
|
97 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
98 |
+
},
|
99 |
+
"type": "CODE"
|
100 |
+
},
|
101 |
+
"gather": {
|
102 |
+
"logged": 1706503443742
|
103 |
+
},
|
104 |
+
"tags": []
|
105 |
+
}
|
106 |
+
},
|
107 |
+
{
|
108 |
+
"cell_type": "code",
|
109 |
+
"source": [
|
110 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
111 |
+
"\n",
|
112 |
+
"SEED: int = 42\n",
|
113 |
+
"\n",
|
114 |
+
"BATCH_SIZE: int = 32\n",
|
115 |
+
"EPOCHS: int = 5\n",
|
116 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
117 |
+
"\n",
|
118 |
+
"# WandB configuration\n",
|
119 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
120 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
121 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
122 |
+
],
|
123 |
+
"outputs": [],
|
124 |
+
"execution_count": 4,
|
125 |
+
"metadata": {
|
126 |
+
"collapsed": false,
|
127 |
+
"gather": {
|
128 |
+
"logged": 1706503443899
|
129 |
+
},
|
130 |
+
"jupyter": {
|
131 |
+
"outputs_hidden": false
|
132 |
+
}
|
133 |
+
}
|
134 |
+
},
|
135 |
+
{
|
136 |
+
"cell_type": "code",
|
137 |
+
"source": [
|
138 |
+
"%watermark --iversion"
|
139 |
+
],
|
140 |
+
"outputs": [
|
141 |
+
{
|
142 |
+
"output_type": "stream",
|
143 |
+
"name": "stdout",
|
144 |
+
"text": "shap : 0.44.1\nnumpy : 1.23.5\npandas : 2.0.2\nlogging : 0.5.1.2\ntorch : 1.12.0\nevaluate: 0.4.1\nwandb : 0.16.2\nre : 2.2.1\n\n"
|
145 |
+
}
|
146 |
+
],
|
147 |
+
"execution_count": 5,
|
148 |
+
"metadata": {
|
149 |
+
"collapsed": false,
|
150 |
+
"jupyter": {
|
151 |
+
"outputs_hidden": false
|
152 |
+
}
|
153 |
+
}
|
154 |
+
},
|
155 |
+
{
|
156 |
+
"cell_type": "code",
|
157 |
+
"source": [
|
158 |
+
"!nvidia-smi"
|
159 |
+
],
|
160 |
+
"outputs": [
|
161 |
+
{
|
162 |
+
"output_type": "stream",
|
163 |
+
"name": "stdout",
|
164 |
+
"text": "Mon Jan 29 04:44:03 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
165 |
+
}
|
166 |
+
],
|
167 |
+
"execution_count": 6,
|
168 |
+
"metadata": {
|
169 |
+
"datalore": {
|
170 |
+
"hide_input_from_viewers": true,
|
171 |
+
"hide_output_from_viewers": true,
|
172 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
173 |
+
"type": "CODE"
|
174 |
+
}
|
175 |
+
}
|
176 |
+
},
|
177 |
+
{
|
178 |
+
"cell_type": "markdown",
|
179 |
+
"source": [
|
180 |
+
"## Loading the data set"
|
181 |
+
],
|
182 |
+
"metadata": {
|
183 |
+
"datalore": {
|
184 |
+
"hide_input_from_viewers": false,
|
185 |
+
"hide_output_from_viewers": false,
|
186 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
187 |
+
"report_properties": {
|
188 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
189 |
+
},
|
190 |
+
"type": "MD"
|
191 |
+
}
|
192 |
+
}
|
193 |
+
},
|
194 |
+
{
|
195 |
+
"cell_type": "code",
|
196 |
+
"source": [
|
197 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
198 |
+
],
|
199 |
+
"outputs": [],
|
200 |
+
"execution_count": 7,
|
201 |
+
"metadata": {
|
202 |
+
"collapsed": false,
|
203 |
+
"gather": {
|
204 |
+
"logged": 1706503446033
|
205 |
+
},
|
206 |
+
"jupyter": {
|
207 |
+
"outputs_hidden": false
|
208 |
+
}
|
209 |
+
}
|
210 |
+
},
|
211 |
+
{
|
212 |
+
"cell_type": "code",
|
213 |
+
"source": [
|
214 |
+
"dataset"
|
215 |
+
],
|
216 |
+
"outputs": [
|
217 |
+
{
|
218 |
+
"output_type": "execute_result",
|
219 |
+
"execution_count": 8,
|
220 |
+
"data": {
|
221 |
+
"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
|
222 |
+
},
|
223 |
+
"metadata": {}
|
224 |
+
}
|
225 |
+
],
|
226 |
+
"execution_count": 8,
|
227 |
+
"metadata": {
|
228 |
+
"collapsed": false,
|
229 |
+
"gather": {
|
230 |
+
"logged": 1706503446252
|
231 |
+
},
|
232 |
+
"jupyter": {
|
233 |
+
"outputs_hidden": false,
|
234 |
+
"source_hidden": false
|
235 |
+
},
|
236 |
+
"nteract": {
|
237 |
+
"transient": {
|
238 |
+
"deleting": false
|
239 |
+
}
|
240 |
+
}
|
241 |
+
}
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"source": [
|
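+
"# Optionally train on a random fraction of each split; SUBSAMPLING = 1.0 uses the full data set.\n",
|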
246 |
+
"SUBSAMPLING = 1.0\n",
|
247 |
+
"\n",
|
248 |
+
"if SUBSAMPLING < 1:\n",
|
249 |
+
" _ = DatasetDict()\n",
|
250 |
+
" for each in dataset.keys():\n",
|
251 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
252 |
+
"\n",
|
253 |
+
" dataset = _"
|
254 |
+
],
|
255 |
+
"outputs": [],
|
256 |
+
"execution_count": 9,
|
257 |
+
"metadata": {
|
258 |
+
"gather": {
|
259 |
+
"logged": 1706503446498
|
260 |
+
}
|
261 |
+
}
|
262 |
+
},
|
263 |
+
{
|
264 |
+
"cell_type": "markdown",
|
265 |
+
"source": [
|
266 |
+
"## Tokenisation and encoding"
|
267 |
+
],
|
268 |
+
"metadata": {}
|
269 |
+
},
|
270 |
+
{
|
271 |
+
"cell_type": "code",
|
272 |
+
"source": [
|
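+
"# Sketch of the tokenisation helper: the original stub returned an undefined ds_enc. Unlike the inline cell below, this version keeps all columns.\n",
|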
273 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
274 |
+
" return ds_enc"
|
275 |
+
],
|
276 |
+
"outputs": [],
|
277 |
+
"execution_count": 10,
|
278 |
+
"metadata": {
|
279 |
+
"gather": {
|
280 |
+
"logged": 1706503446633
|
281 |
+
}
|
282 |
+
}
|
283 |
+
},
|
284 |
+
{
|
285 |
+
"cell_type": "markdown",
|
286 |
+
"source": [
|
287 |
+
"## Evaluation metrics"
|
288 |
+
],
|
289 |
+
"metadata": {}
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"source": [
|
294 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
295 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
296 |
+
"f1 = evaluate.load(\"f1\")"
|
297 |
+
],
|
298 |
+
"outputs": [],
|
299 |
+
"execution_count": 11,
|
300 |
+
"metadata": {
|
301 |
+
"gather": {
|
302 |
+
"logged": 1706503446863
|
303 |
+
}
|
304 |
+
}
|
305 |
+
},
|
306 |
+
{
|
307 |
+
"cell_type": "code",
|
308 |
+
"source": [
|
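+
"# Converts logits to class ids and reports accuracy, macro/micro precision and recall, and micro F1.\n",
|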
309 |
+
"def compute_metrics(eval_pred):\n",
|
310 |
+
" predictions, labels = eval_pred\n",
|
311 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
312 |
+
" return {\n",
|
313 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
314 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
315 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
316 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
317 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
318 |
+
" 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
319 |
+
" }"
|
320 |
+
],
|
321 |
+
"outputs": [],
|
322 |
+
"execution_count": 12,
|
323 |
+
"metadata": {
|
324 |
+
"gather": {
|
325 |
+
"logged": 1706503447004
|
326 |
+
}
|
327 |
+
}
|
328 |
+
},
|
329 |
+
{
|
330 |
+
"cell_type": "markdown",
|
331 |
+
"source": [
|
332 |
+
"## Training"
|
333 |
+
],
|
334 |
+
"metadata": {}
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"cell_type": "markdown",
|
338 |
+
"source": [
|
339 |
+
"We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
|
340 |
+
],
|
341 |
+
"metadata": {}
|
342 |
+
},
|
343 |
+
{
|
344 |
+
"cell_type": "code",
|
345 |
+
"source": [
|
346 |
+
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
347 |
+
],
|
348 |
+
"outputs": [],
|
349 |
+
"execution_count": 13,
|
350 |
+
"metadata": {
|
351 |
+
"gather": {
|
352 |
+
"logged": 1706503447186
|
353 |
+
}
|
354 |
+
}
|
355 |
+
},
|
356 |
+
{
|
357 |
+
"cell_type": "code",
|
358 |
+
"source": [
|
359 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
360 |
+
"\n",
|
361 |
+
"cols = dataset[\"train\"].column_names\n",
|
362 |
+
"cols.remove(\"label\")\n",
|
363 |
+
"ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n"
|
364 |
+
],
|
365 |
+
"outputs": [
|
366 |
+
{
|
367 |
+
"output_type": "stream",
|
368 |
+
"name": "stderr",
|
369 |
+
"text": "Map: 100%|██████████| 272238/272238 [01:45<00:00, 2592.04 examples/s]\n"
|
370 |
+
}
|
371 |
+
],
|
372 |
+
"execution_count": 14,
|
373 |
+
"metadata": {
|
374 |
+
"gather": {
|
375 |
+
"logged": 1706503552083
|
376 |
+
}
|
377 |
+
}
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "code",
|
381 |
+
"source": [
|
382 |
+
"\n",
|
383 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
384 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
385 |
+
" id2label=label_map, \n",
|
386 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
387 |
+
"\n",
|
388 |
+
"args = TrainingArguments(\n",
|
389 |
+
" output_dir=\"vaers\",\n",
|
390 |
+
" evaluation_strategy=\"epoch\",\n",
|
391 |
+
" save_strategy=\"epoch\",\n",
|
392 |
+
" learning_rate=2e-5,\n",
|
393 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
394 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
395 |
+
" num_train_epochs=EPOCHS,\n",
|
396 |
+
" weight_decay=.01,\n",
|
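+
" # NB: logging_steps=1 logs every optimiser step, which is very chatty on a ~50k-step run.\n",
|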
397 |
+
" logging_steps=1,\n",
|
398 |
+
" load_best_model_at_end=True,\n",
|
399 |
+
" run_name=f\"daedra-training\",\n",
|
400 |
+
" report_to=[\"wandb\"])\n",
|
401 |
+
"\n",
|
402 |
+
"trainer = Trainer(\n",
|
403 |
+
" model=model,\n",
|
404 |
+
" args=args,\n",
|
405 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
406 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
407 |
+
" tokenizer=tokenizer,\n",
|
408 |
+
" compute_metrics=compute_metrics)"
|
409 |
+
],
|
410 |
+
"outputs": [
|
411 |
+
{
|
412 |
+
"output_type": "stream",
|
413 |
+
"name": "stderr",
|
414 |
+
"text": "Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\nYou should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
415 |
+
}
|
416 |
+
],
|
417 |
+
"execution_count": 15,
|
418 |
+
"metadata": {
|
419 |
+
"gather": {
|
420 |
+
"logged": 1706503554669
|
421 |
+
}
|
422 |
+
}
|
423 |
+
},
|
424 |
+
{
|
425 |
+
"cell_type": "code",
|
426 |
+
"source": [
|
427 |
+
"if SUBSAMPLING != 1.0:\n",
|
428 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
429 |
+
"else:\n",
|
430 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
431 |
+
"\n",
|
432 |
+
"wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
433 |
+
"wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
434 |
+
" \n",
|
435 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
|
436 |
+
],
|
437 |
+
"outputs": [
|
438 |
+
{
|
439 |
+
"output_type": "stream",
|
440 |
+
"name": "stderr",
|
441 |
+
"text": "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
442 |
+
},
|
443 |
+
{
|
444 |
+
"output_type": "display_data",
|
445 |
+
"data": {
|
446 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
447 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
448 |
+
},
|
449 |
+
"metadata": {}
|
450 |
+
},
|
451 |
+
{
|
452 |
+
"output_type": "display_data",
|
453 |
+
"data": {
|
454 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
455 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_044555-kjhyoltp</code>"
|
456 |
+
},
|
457 |
+
"metadata": {}
|
458 |
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp</a>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Finishing last run (ID:kjhyoltp) before initializing another..."
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "W&B sync reduced upload amount by 26.5% "
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kjhyoltp</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v1' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v1</a><br/>Synced 5 W&B file(s), 0 media file(s), 2 artifact file(s) and 0 other file(s)"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Find logs at: <code>./wandb/run-20240129_044555-kjhyoltp/logs</code>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Successfully finished last run (ID:kjhyoltp). Initializing new run:<br/>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Tracking run with wandb version 0.16.2"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_044558-ed51hqn6</code>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6</a>"
+},
+"metadata": {}
+},
+{
+"output_type": "execute_result",
+"execution_count": 16,
+"data": {
+"text/html": "<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/ed51hqn6?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>",
+"text/plain": "<wandb.sdk.wandb_run.Run at 0x7f19d09fdbe0>"
+},
+"metadata": {}
+}
+],
+"execution_count": 16,
+"metadata": {
+"gather": {
+"logged": 1706503566090
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"tracker.start()\n",
+"trainer.train()\n",
+"tracker.stop()\n"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stderr",
+"text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "\n <div>\n \n <progress value='183' max='49630' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 183/49630 01:56 < 8:51:06, 1.55 it/s, Epoch 0.02/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
+},
+"metadata": {}
+},
+{
+"output_type": "stream",
+"name": "stderr",
+"text": "[codecarbon INFO @ 04:46:20] Energy consumed for RAM : 0.000690 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:46:20] Energy consumed for all GPUs : 0.001499 kWh. Total GPU Power : 359.1829830586385 W\n[codecarbon INFO @ 04:46:20] Energy consumed for all CPUs : 0.000177 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:46:20] 0.002366 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:46:35] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:46:35] Energy consumed for all GPUs : 0.004078 kWh. Total GPU Power : 619.6193403526773 W\n[codecarbon INFO @ 04:46:35] Energy consumed for all CPUs : 0.000355 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:46:35] 0.005811 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:46:50] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:46:50] Energy consumed for all GPUs : 0.006632 kWh. Total GPU Power : 613.6554096062922 W\n[codecarbon INFO @ 04:46:50] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:46:50] 0.009230 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:05] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:05] Energy consumed for all GPUs : 0.009249 kWh. Total GPU Power : 628.5574609453653 W\n[codecarbon INFO @ 04:47:05] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:05] 0.012712 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:20] Energy consumed for RAM : 0.003442 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:20] Energy consumed for all GPUs : 0.011850 kWh. Total GPU Power : 624.8454173521444 W\n[codecarbon INFO @ 04:47:20] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:20] 0.016178 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:35] Energy consumed for RAM : 0.004130 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:35] Energy consumed for all GPUs : 0.014490 kWh. Total GPU Power : 634.7378588005432 W\n[codecarbon INFO @ 04:47:35] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:35] 0.019683 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:47:50] Energy consumed for RAM : 0.004818 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:47:50] Energy consumed for all GPUs : 0.017140 kWh. Total GPU Power : 636.6500188212152 W\n[codecarbon INFO @ 04:47:50] Energy consumed for all CPUs : 0.001240 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:47:50] 0.023197 kWh of electricity used since the beginning.\n[codecarbon INFO @ 04:48:05] Energy consumed for RAM : 0.005506 kWh. RAM Power : 165.33123922348022 W\n[codecarbon INFO @ 04:48:05] Energy consumed for all GPUs : 0.019771 kWh. Total GPU Power : 631.881788173399 W\n[codecarbon INFO @ 04:48:05] Energy consumed for all CPUs : 0.001417 kWh. Total CPU Power : 42.5 W\n[codecarbon INFO @ 04:48:05] 0.026694 kWh of electricity used since the beginning.\n"
+}
+],
+"execution_count": 17,
+"metadata": {
+"gather": {
+"logged": 1706486541798
+}
+}
+},
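For context, `tracker` in the cell above is a codecarbon `EmissionsTracker` (the package is installed earlier in the notebook) and `trainer` is the Hugging Face `Trainer` built earlier. A minimal, self-contained sketch of the same pattern — the `project_name` value and the `try`/`finally` guard are illustrative additions, not taken from the notebook:

```python
from codecarbon import EmissionsTracker

tracker = EmissionsTracker(project_name="daedra_training")  # illustrative name

tracker.start()                   # begin sampling CPU, GPU and RAM power draw
try:
    trainer.train()               # the Trainer instance configured earlier
finally:
    emissions = tracker.stop()    # estimated kg CO2eq; also written to emissions.csv

print(f"Estimated emissions: {emissions:.4f} kg CO2eq")
```

Putting `tracker.stop()` in a `finally` block ensures the energy log is flushed even if training is interrupted, consistent with the periodic `[codecarbon INFO]` flushes visible in the output above.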
+{
+"cell_type": "code",
+"source": [
+"wandb.finish()"
+],
+"outputs": [],
+"execution_count": null,
+"metadata": {
+"gather": {
+"logged": 1706486541918
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+"tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+"tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+"sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+"\n",
+"model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+" variant=variant,\n",
+" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
+],
+"outputs": [],
+"execution_count": null,
+"metadata": {
+"gather": {
+"logged": 1706486541928
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+"tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+"tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+"sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+"\n",
+"model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+" variant=variant,\n",
+" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
+],
+"outputs": [],
+"execution_count": null,
+"metadata": {}
+}
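The two identical cells above (the checkpoint captured the same publishing cell twice) push the tokenizer and the fine-tuned model to the Hub. Reduced to its essentials, the pattern looks like this — a sketch using the names from the notebook, with the commit message shortened for readability:

```python
# Derive a variant name from the subsampling fraction, then publish both artefacts.
variant = "full_sample" if SUBSAMPLING == 1.0 else f"subsample-{SUBSAMPLING}"

tokenizer.push_to_hub("chrisvoncsefalvay/daedra")
model.push_to_hub(
    "chrisvoncsefalvay/daedra",
    variant=variant,  # forwarded to save_pretrained()
    commit_message=f"DAEDRA model trained on {SUBSAMPLING * 100}% of the VAERS dataset",
)
```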
+],
+"metadata": {
+"datalore": {
+"base_environment": "default",
+"computation_mode": "JUPYTER",
+"package_manager": "pip",
+"packages": [
+{
+"name": "datasets",
+"source": "PIP",
+"version": "2.16.1"
+},
+{
+"name": "torch",
+"source": "PIP",
+"version": "2.1.2"
+},
+{
+"name": "accelerate",
+"source": "PIP",
+"version": "0.26.1"
+}
+],
+"report_row_ids": [
+"un8W7ez7ZwoGb5Co6nydEV",
+"40nN9Hvgi1clHNV5RAemI5",
+"TgRD90H5NSPpKS41OeXI1w",
+"ZOm5BfUs3h1EGLaUkBGeEB",
+"kOP0CZWNSk6vqE3wkPp7Vc",
+"W4PWcOu2O2pRaZyoE2W80h",
+"RolbOnQLIftk0vy9mIcz5M",
+"8OPhUgbaNJmOdiq5D3a6vK",
+"5Qrt3jSvSrpK6Ne1hS6shL",
+"hTq7nFUrovN5Ao4u6dIYWZ",
+"I8WNZLpJ1DVP2wiCW7YBIB",
+"SawhU3I9BewSE1XBPstpNJ",
+"80EtLEl2FIE4FqbWnUD3nT"
+],
+"version": 3
+},
+"kernel_info": {
+"name": "python38-azureml-pt-tf"
+},
+"kernelspec": {
+"display_name": "azureml_py38_PT_TF",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"name": "python",
+"version": "3.8.5",
+"mimetype": "text/x-python",
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"pygments_lexer": "ipython3",
+"nbconvert_exporter": "python",
+"file_extension": ".py"
+},
+"microsoft": {
+"host": {
+"AzureML": {
+"notebookHasBeenCompleted": true
+}
+},
+"ms_spell_check": {
+"ms_spell_check_language": "en"
+}
+},
+"nteract": {
+"version": "nteract-front-end@1.0.0"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 4
+}
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-16-5-15Z.ipynb
ADDED
@@ -0,0 +1,729 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
+"\n",
+"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 1,
+"metadata": {
+"gather": {
+"logged": 1706475754655
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+},
+"tags": []
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
+"Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
+"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
+"Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
+"Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
+"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
+"Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
+"Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
+"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
+"Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
+"Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
+"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
+"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
+"Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
+"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
+"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
+"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
+"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
+"Note: you may need to restart the kernel to use updated packages.\n"
+]
+}
+],
+"source": [
+"%pip install accelerate -U"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 2,
+"metadata": {
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
+"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
+"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
+"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
+"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
+"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
+"Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
+"Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
+"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
+"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
+"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
+"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
+"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
+"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
+"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
+"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
+"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
+"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
+"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
+"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
+"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
+"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
+"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
+"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
+"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
+"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
+"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
+"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
+"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
+"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
+"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
+"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
+"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
+"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
+"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
+"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
+"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
+"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
+"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
+"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
+"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
+"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
+"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
+"Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
+"Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
+"Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
+"Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
+"Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
+"Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
+"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
+"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
+"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
+"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
+"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
+"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
+"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
+"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
+"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
+"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
+"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
+"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
+"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
+"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
+"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
+"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
+"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
+"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
+"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
+"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
+"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
+"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
+"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
+"Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
+"Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
+"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
+"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
+"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
+"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
+"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
+"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
+"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
+"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
+"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
+"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
+"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
+"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
+"Note: you may need to restart the kernel to use updated packages.\n"
+]
+}
+],
+"source": [
+"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 28,
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"node_id": "caZjjFP0OyQNMVgZDiwswE",
+"report_properties": {
+"rowId": "un8W7ez7ZwoGb5Co6nydEV"
+},
+"type": "CODE"
+},
+"gather": {
+"logged": 1706503443742
+},
+"tags": []
+},
+"outputs": [
+{
+"data": {
+"text/html": [
+" View run <strong style=\"color:#cdcd00\">daedra_0.05-distilbert-base-uncased</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
+],
+"text/plain": [
+"<IPython.core.display.HTML object>"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"data": {
+"text/html": [
+"Find logs at: <code>./wandb/run-20240129_152136-cwkdl3x7/logs</code>"
+],
+"text/plain": [
+"<IPython.core.display.HTML object>"
+]
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"The watermark extension is already loaded. To reload it, use:\n",
+" %reload_ext watermark\n"
+]
+}
+],
+"source": [
+"import pandas as pd\n",
+"import numpy as np\n",
+"import torch\n",
+"import os\n",
+"from typing import List, Union\n",
+"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
+"from datasets import load_dataset, Dataset, DatasetDict\n",
+"import shap\n",
+"import wandb\n",
+"import evaluate\n",
+"import logging\n",
+"\n",
+"wandb.finish()\n",
+"\n",
+"\n",
+"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
+"\n",
+"%load_ext watermark"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 4,
+"metadata": {
+"collapsed": false,
+"gather": {
+"logged": 1706503443899
+},
+"jupyter": {
+"outputs_hidden": false
+}
+},
+"outputs": [],
+"source": [
+"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+"\n",
+"SEED: int = 42\n",
+"\n",
+"BATCH_SIZE: int = 32\n",
+"EPOCHS: int = 5\n",
+"model_ckpt: str = \"distilbert-base-uncased\"\n",
+"\n",
+"# WandB configuration\n",
+"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
+"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
+"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
+]
+},
+{
+"cell_type": "code",
+"execution_count": 5,
+"metadata": {
+"collapsed": false,
+"jupyter": {
+"outputs_hidden": false
+}
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"re : 2.2.1\n",
+"torch : 1.12.0\n",
+"wandb : 0.16.2\n",
+"logging : 0.5.1.2\n",
+"numpy : 1.23.5\n",
+"pandas : 2.0.2\n",
+"evaluate: 0.4.1\n",
+"shap : 0.44.1\n",
+"\n"
+]
+}
+],
+"source": [
+"%watermark --iversion"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 6,
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true,
+"node_id": "UU2oOJhwbIualogG1YyCMd",
+"type": "CODE"
+}
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
+"Mon Jan 29 15:20:22 2024 \n",
+"+---------------------------------------------------------------------------------------+\n",
+"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
+"|-----------------------------------------+----------------------+----------------------+\n",
+"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
+"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
+"| | | MIG M. |\n",
+"|=========================================+======================+======================|\n",
+"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
+"| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
+"| | | N/A |\n",
+"+-----------------------------------------+----------------------+----------------------+\n",
+"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
+"| N/A 26C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
+"| | | N/A |\n",
+"+-----------------------------------------+----------------------+----------------------+\n",
+"| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
+"| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
+"| | | N/A |\n",
+"+-----------------------------------------+----------------------+----------------------+\n",
+"| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
+"| N/A 28C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
+"| | | N/A |\n",
+"+-----------------------------------------+----------------------+----------------------+\n",
+" \n",
+"+---------------------------------------------------------------------------------------+\n",
+"| Processes: |\n",
+"| GPU GI CI PID Type Process name GPU Memory |\n",
+"| ID ID Usage |\n",
+"|=======================================================================================|\n",
+"| No running processes found |\n",
+"+---------------------------------------------------------------------------------------+\n"
+]
+}
+],
+"source": [
+"!nvidia-smi"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"node_id": "t45KHugmcPVaO0nuk8tGJ9",
+"report_properties": {
+"rowId": "40nN9Hvgi1clHNV5RAemI5"
+},
+"type": "MD"
+}
+},
+"source": [
+"## Loading the data set"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 7,
+"metadata": {
+"collapsed": false,
+"gather": {
+"logged": 1706503446033
+},
+"jupyter": {
+"outputs_hidden": false
+}
+},
+"outputs": [],
+"source": [
+"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"metadata": {
+"collapsed": false,
+"gather": {
+"logged": 1706503446252
+},
+"jupyter": {
+"outputs_hidden": false,
+"source_hidden": false
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+},
+"outputs": [
+{
+"data": {
+"text/plain": [
+"DatasetDict({\n",
+" train: Dataset({\n",
+" features: ['id', 'text', 'label'],\n",
+" num_rows: 1270444\n",
+" })\n",
+" test: Dataset({\n",
+" features: ['id', 'text', 'label'],\n",
+" num_rows: 272238\n",
+" })\n",
+" val: Dataset({\n",
+" features: ['id', 'text', 'label'],\n",
+" num_rows: 272238\n",
+" })\n",
+"})"
+]
+},
+"execution_count": 8,
+"metadata": {},
+"output_type": "execute_result"
+}
+],
+"source": [
+"dataset"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 8,
+"metadata": {
+"gather": {
+"logged": 1706503446498
+}
+},
+"outputs": [],
+"source": [
+"SUBSAMPLING = 0.01\n",
+"\n",
+"if SUBSAMPLING < 1:\n",
+" _ = DatasetDict()\n",
+" for each in dataset.keys():\n",
+" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
+"\n",
+" dataset = _"
+]
+},
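The subsampling cell above shuffles each split with a fixed seed and keeps the first `SUBSAMPLING` fraction of rows. The same idea as a reusable helper — a minimal sketch; the function name `subsample` is illustrative and not from the notebook:

```python
from datasets import DatasetDict

def subsample(ds: DatasetDict, fraction: float, seed: int = 42) -> DatasetDict:
    """Shuffle each split deterministically and keep the first `fraction` of rows."""
    if fraction >= 1.0:
        return ds
    out = DatasetDict()
    for name, split in ds.items():
        out[name] = split.shuffle(seed=seed).select(range(int(len(split) * fraction)))
    return out

# e.g. dataset = subsample(dataset, 0.01, seed=SEED)
```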
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Tokenisation and encoding"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {
+"gather": {
+"logged": 1706503446633
+}
+},
+"outputs": [],
+"source": [
+"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
+"    # Tokenise every split, keeping only the label column alongside the encodings.\n",
+"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)\n",
+"    cols = [c for c in ds[\"train\"].column_names if c != \"label\"]\n",
+"    ds_enc = ds.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
+"    return ds_enc"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Evaluation metrics"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 9,
+"metadata": {
+"gather": {
+"logged": 1706503446863
+}
+},
+"outputs": [],
+"source": [
+"accuracy = evaluate.load(\"accuracy\")\n",
+"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
+"f1 = evaluate.load(\"f1\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 10,
+"metadata": {
+"gather": {
+"logged": 1706503447004
+}
+},
+"outputs": [],
+"source": [
+"def compute_metrics(eval_pred):\n",
+"    predictions, labels = eval_pred\n",
+"    predictions = np.argmax(predictions, axis=1)\n",
+"    return {\n",
+"        'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
+"        'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
+"        'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
+"        'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
+"        'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
+"        'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
+"    }"
+]
+},
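A quick way to sanity-check `compute_metrics` outside a training run is to feed it a synthetic `(logits, labels)` pair — the values below are fabricated purely for illustration:

```python
import numpy as np

fake_logits = np.array([[0.1, 0.7, 0.2],   # argmax -> class 1
                        [0.8, 0.1, 0.1]])  # argmax -> class 0
fake_labels = np.array([1, 0])

# Returns a dict with accuracy, macro/micro precision and recall, and micro F1.
print(compute_metrics((fake_logits, fake_labels)))
```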
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"## Training"
+]
+},
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"We specify a label map – this has to be done manually, even though `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"metadata": {
+"gather": {
+"logged": 1706503447186
+}
+},
+"outputs": [],
+"source": [
+"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
+]
+},
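Since `AutoModelForSequenceClassification` needs the mapping in both directions, the inverse can be derived directly from `label_map`, exactly as the training function below does — a minimal sketch:

```python
# id -> label, built from the dataset's ClassLabel feature (as in the cell above).
label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}

# label -> id, as passed to the model's label2id argument.
label2id = {v: k for k, v in label_map.items()}
```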
+{
+"cell_type": "code",
+"execution_count": 12,
+"metadata": {
+"jupyter": {
+"outputs_hidden": false,
+"source_hidden": false
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+},
+"outputs": [],
+"source": [
+"def train_from_model(model_ckpt: str, push: bool = False):\n",
+"    print(f\"Initialising training based on {model_ckpt}...\")\n",
+"\n",
+"    print(\"Tokenising...\")\n",
+"    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
+"\n",
+"    cols = dataset[\"train\"].column_names\n",
+"    cols.remove(\"label\")\n",
+"    ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
+"\n",
+"    print(\"Loading model...\")\n",
+"    try:\n",
+"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
+"            num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
+"            id2label=label_map, \n",
+"            label2id={v:k for k,v in label_map.items()})\n",
+"    except OSError:\n",
+"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
+"            num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
+"            id2label=label_map, \n",
+"            label2id={v:k for k,v in label_map.items()},\n",
+"            from_tf=True)\n",
+"\n",
+"\n",
+"    args = TrainingArguments(\n",
+"        output_dir=\"vaers\",\n",
+"        evaluation_strategy=\"epoch\",\n",
+"        save_strategy=\"epoch\",\n",
+"        learning_rate=2e-5,\n",
+"        per_device_train_batch_size=BATCH_SIZE,\n",
+"        per_device_eval_batch_size=BATCH_SIZE,\n",
+"        num_train_epochs=EPOCHS,\n",
+"        weight_decay=.01,\n",
+"        logging_steps=1,\n",
+"        load_best_model_at_end=True,\n",
+"        run_name=f\"daedra-training\",\n",
+"        report_to=[\"wandb\"])\n",
+"\n",
+"    trainer = Trainer(\n",
+"        model=model,\n",
+"        args=args,\n",
+"        train_dataset=ds_enc[\"train\"],\n",
+"        eval_dataset=ds_enc[\"test\"],\n",
+"        tokenizer=tokenizer,\n",
+"        compute_metrics=compute_metrics)\n",
+"    \n",
+"    if SUBSAMPLING != 1.0:\n",
+"        wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
+"    else:\n",
+"        wandb_tag: List[str] = [f\"full_sample\"]\n",
+"\n",
+"    wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
+"    wandb_tag.append(f\"base:{model_ckpt}\")\n",
+"    \n",
+"    wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
+"\n",
+"    print(\"Starting training...\")\n",
+"\n",
+"    trainer.train()\n",
+"\n",
+"    print(\"Training finished.\")\n",
+"\n",
+"    if push:\n",
+"        variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+"        tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+"        tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+"        sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+"\n",
+"        model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+"            variant=variant,\n",
+"            commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 13,
+"metadata": {
+"gather": {
+"logged": 1706503552083
+}
+},
+"outputs": [],
+"source": [
+"\n",
+"base_models = [\n",
+"    \"bert-base-uncased\",\n",
+"    \"distilbert-base-uncased\",\n",
+"]"
+]
+}
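`base_models` is presumably consumed by a driver loop that fine-tunes each candidate in turn; the checkpoint cuts off before that cell, so the loop below is an assumption rather than the notebook's own code — only `base_models` and `train_from_model` come from the cells above:

```python
for ckpt in base_models:
    # Fine-tune a classifier from each base checkpoint; push=False keeps
    # comparison runs off the Hub.
    train_from_model(ckpt, push=False)
```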
+],
+"metadata": {
+"datalore": {
+"base_environment": "default",
+"computation_mode": "JUPYTER",
+"package_manager": "pip",
+"packages": [
+{
+"name": "datasets",
+"source": "PIP",
+"version": "2.16.1"
+},
+{
+"name": "torch",
+"source": "PIP",
+"version": "2.1.2"
+},
+{
+"name": "accelerate",
+"source": "PIP",
+"version": "0.26.1"
+}
+],
+"report_row_ids": [
+"un8W7ez7ZwoGb5Co6nydEV",
+"40nN9Hvgi1clHNV5RAemI5",
+"TgRD90H5NSPpKS41OeXI1w",
+"ZOm5BfUs3h1EGLaUkBGeEB",
+"kOP0CZWNSk6vqE3wkPp7Vc",
+"W4PWcOu2O2pRaZyoE2W80h",
+"RolbOnQLIftk0vy9mIcz5M",
+"8OPhUgbaNJmOdiq5D3a6vK",
+"5Qrt3jSvSrpK6Ne1hS6shL",
+"hTq7nFUrovN5Ao4u6dIYWZ",
+"I8WNZLpJ1DVP2wiCW7YBIB",
+"SawhU3I9BewSE1XBPstpNJ",
+"80EtLEl2FIE4FqbWnUD3nT"
+],
+"version": 3
+},
+"kernel_info": {
+"name": "python38-azureml-pt-tf"
+},
+"kernelspec": {
+"display_name": "azureml_py38_PT_TF",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.8.5"
+},
+"microsoft": {
+"host": {
+"AzureML": {
+"notebookHasBeenCompleted": true
+}
+},
+"ms_spell_check": {
+"ms_spell_check_language": "en"
+}
+},
+"nteract": {
+"version": "nteract-front-end@1.0.0"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 4
+}
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-17-44-52Z.ipynb
ADDED
@@ -0,0 +1,739 @@
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
|
83 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
|
84 |
+
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
|
85 |
+
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
|
86 |
+
"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
|
87 |
+
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
|
88 |
+
"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
|
89 |
+
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
|
90 |
+
"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
|
91 |
+
"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
|
92 |
+
"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
|
93 |
+
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
|
94 |
+
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
|
95 |
+
"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
|
96 |
+
"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
|
97 |
+
"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
|
98 |
+
"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
|
99 |
+
"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
|
100 |
+
"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
|
101 |
+
"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
|
102 |
+
"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
|
103 |
+
"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
|
104 |
+
"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
|
105 |
+
"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
|
106 |
+
"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
|
107 |
+
"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
|
108 |
+
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
|
109 |
+
"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
|
110 |
+
"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
|
111 |
+
"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
|
112 |
+
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
|
113 |
+
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
|
114 |
+
"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
|
115 |
+
"Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
|
116 |
+
"Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
|
117 |
+
"Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
|
118 |
+
"Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
|
119 |
+
"Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
|
120 |
+
"Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
|
121 |
+
"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
|
122 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
|
123 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
|
124 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
|
125 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
|
126 |
+
"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
|
127 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
|
128 |
+
"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
|
129 |
+
"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
|
130 |
+
"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
|
131 |
+
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
|
132 |
+
"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
|
133 |
+
"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
|
134 |
+
"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
|
135 |
+
"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
|
136 |
+
"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
|
137 |
+
"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
|
138 |
+
"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
|
139 |
+
"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
|
140 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
|
141 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
|
142 |
+
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
|
143 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
|
144 |
+
"Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
|
145 |
+
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
|
146 |
+
"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
|
147 |
+
"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
148 |
+
"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
149 |
+
"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
|
150 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
|
151 |
+
"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
|
152 |
+
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
|
153 |
+
"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
|
154 |
+
"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
|
155 |
+
"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
|
156 |
+
"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
|
157 |
+
"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
|
158 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
159 |
+
]
|
160 |
+
}
|
161 |
+
],
|
162 |
+
"source": [
|
163 |
+
"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": 28,
|
169 |
+
"metadata": {
|
170 |
+
"datalore": {
|
171 |
+
"hide_input_from_viewers": false,
|
172 |
+
"hide_output_from_viewers": false,
|
173 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
174 |
+
"report_properties": {
|
175 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
176 |
+
},
|
177 |
+
"type": "CODE"
|
178 |
+
},
|
179 |
+
"gather": {
|
180 |
+
"logged": 1706503443742
|
181 |
+
},
|
182 |
+
"tags": []
|
183 |
+
},
|
184 |
+
"outputs": [
|
185 |
+
{
|
186 |
+
"data": {
|
187 |
+
"text/html": [
|
188 |
+
" View run <strong style=\"color:#cdcd00\">daedra_0.05-distilbert-base-uncased</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/cwkdl3x7</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v3</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
189 |
+
],
|
190 |
+
"text/plain": [
|
191 |
+
"<IPython.core.display.HTML object>"
|
192 |
+
]
|
193 |
+
},
|
194 |
+
"metadata": {},
|
195 |
+
"output_type": "display_data"
|
196 |
+
},
|
197 |
+
{
|
198 |
+
"data": {
|
199 |
+
"text/html": [
|
200 |
+
"Find logs at: <code>./wandb/run-20240129_152136-cwkdl3x7/logs</code>"
|
201 |
+
],
|
202 |
+
"text/plain": [
|
203 |
+
"<IPython.core.display.HTML object>"
|
204 |
+
]
|
205 |
+
},
|
206 |
+
"metadata": {},
|
207 |
+
"output_type": "display_data"
|
208 |
+
},
|
209 |
+
{
|
210 |
+
"name": "stdout",
|
211 |
+
"output_type": "stream",
|
212 |
+
"text": [
|
213 |
+
"The watermark extension is already loaded. To reload it, use:\n",
|
214 |
+
" %reload_ext watermark\n"
|
215 |
+
]
|
216 |
+
}
|
217 |
+
],
|
218 |
+
"source": [
|
219 |
+
"import pandas as pd\n",
|
220 |
+
"import numpy as np\n",
|
221 |
+
"import torch\n",
|
222 |
+
"import os\n",
|
223 |
+
"from typing import List, Union\n",
|
224 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
|
225 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
226 |
+
"import shap\n",
|
227 |
+
"import wandb\n",
|
228 |
+
"import evaluate\n",
|
229 |
+
"import logging\n",
|
230 |
+
"\n",
|
231 |
+
"wandb.finish()\n",
|
232 |
+
"\n",
|
233 |
+
"\n",
|
234 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
235 |
+
"\n",
|
236 |
+
"%load_ext watermark"
|
237 |
+
]
|
238 |
+
},
|
239 |
+
{
|
240 |
+
"cell_type": "code",
|
241 |
+
"execution_count": 4,
|
242 |
+
"metadata": {
|
243 |
+
"collapsed": false,
|
244 |
+
"gather": {
|
245 |
+
"logged": 1706503443899
|
246 |
+
},
|
247 |
+
"jupyter": {
|
248 |
+
"outputs_hidden": false
|
249 |
+
}
|
250 |
+
},
|
251 |
+
"outputs": [],
|
252 |
+
"source": [
|
253 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
254 |
+
"\n",
|
255 |
+
"SEED: int = 42\n",
|
256 |
+
"\n",
|
257 |
+
"BATCH_SIZE: int = 32\n",
|
258 |
+
"EPOCHS: int = 5\n",
|
259 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
260 |
+
"\n",
|
261 |
+
"# WandB configuration\n",
|
262 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
263 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
264 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
265 |
+
]
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"execution_count": 5,
|
270 |
+
"metadata": {
|
271 |
+
"collapsed": false,
|
272 |
+
"jupyter": {
|
273 |
+
"outputs_hidden": false
|
274 |
+
}
|
275 |
+
},
|
276 |
+
"outputs": [
|
277 |
+
{
|
278 |
+
"name": "stdout",
|
279 |
+
"output_type": "stream",
|
280 |
+
"text": [
|
281 |
+
"re : 2.2.1\n",
|
282 |
+
"torch : 1.12.0\n",
|
283 |
+
"wandb : 0.16.2\n",
|
284 |
+
"logging : 0.5.1.2\n",
|
285 |
+
"numpy : 1.23.5\n",
|
286 |
+
"pandas : 2.0.2\n",
|
287 |
+
"evaluate: 0.4.1\n",
|
288 |
+
"shap : 0.44.1\n",
|
289 |
+
"\n"
|
290 |
+
]
|
291 |
+
}
|
292 |
+
],
|
293 |
+
"source": [
|
294 |
+
"%watermark --iversion"
|
295 |
+
]
|
296 |
+
},
|
297 |
+
{
|
298 |
+
"cell_type": "code",
|
299 |
+
"execution_count": 6,
|
300 |
+
"metadata": {
|
301 |
+
"datalore": {
|
302 |
+
"hide_input_from_viewers": true,
|
303 |
+
"hide_output_from_viewers": true,
|
304 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
305 |
+
"type": "CODE"
|
306 |
+
}
|
307 |
+
},
|
308 |
+
"outputs": [
|
309 |
+
{
|
310 |
+
"name": "stdout",
|
311 |
+
"output_type": "stream",
|
312 |
+
"text": [
|
313 |
+
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
|
314 |
+
"Mon Jan 29 15:20:22 2024 \n",
|
315 |
+
"+---------------------------------------------------------------------------------------+\n",
|
316 |
+
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
|
317 |
+
"|-----------------------------------------+----------------------+----------------------+\n",
|
318 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
319 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
320 |
+
"| | | MIG M. |\n",
|
321 |
+
"|=========================================+======================+======================|\n",
|
322 |
+
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
|
323 |
+
"| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
324 |
+
"| | | N/A |\n",
|
325 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
326 |
+
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
|
327 |
+
"| N/A 26C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
328 |
+
"| | | N/A |\n",
|
329 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
330 |
+
"| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
|
331 |
+
"| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
332 |
+
"| | | N/A |\n",
|
333 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
334 |
+
"| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
|
335 |
+
"| N/A 28C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
336 |
+
"| | | N/A |\n",
|
337 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
338 |
+
" \n",
|
339 |
+
"+---------------------------------------------------------------------------------------+\n",
|
340 |
+
"| Processes: |\n",
|
341 |
+
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
342 |
+
"| ID ID Usage |\n",
|
343 |
+
"|=======================================================================================|\n",
|
344 |
+
"| No running processes found |\n",
|
345 |
+
"+---------------------------------------------------------------------------------------+\n"
|
346 |
+
]
|
347 |
+
}
|
348 |
+
],
|
349 |
+
"source": [
|
350 |
+
"!nvidia-smi"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "markdown",
|
355 |
+
"metadata": {
|
356 |
+
"datalore": {
|
357 |
+
"hide_input_from_viewers": false,
|
358 |
+
"hide_output_from_viewers": false,
|
359 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
360 |
+
"report_properties": {
|
361 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
362 |
+
},
|
363 |
+
"type": "MD"
|
364 |
+
}
|
365 |
+
},
|
366 |
+
"source": [
|
367 |
+
"## Loading the data set"
|
368 |
+
]
|
369 |
+
},
|
370 |
+
{
|
371 |
+
"cell_type": "code",
|
372 |
+
"execution_count": 7,
|
373 |
+
"metadata": {
|
374 |
+
"collapsed": false,
|
375 |
+
"gather": {
|
376 |
+
"logged": 1706503446033
|
377 |
+
},
|
378 |
+
"jupyter": {
|
379 |
+
"outputs_hidden": false
|
380 |
+
}
|
381 |
+
},
|
382 |
+
"outputs": [],
|
383 |
+
"source": [
|
384 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
385 |
+
]
|
386 |
+
},
|
387 |
+
{
|
388 |
+
"cell_type": "code",
|
389 |
+
"execution_count": 8,
|
390 |
+
"metadata": {
|
391 |
+
"collapsed": false,
|
392 |
+
"gather": {
|
393 |
+
"logged": 1706503446252
|
394 |
+
},
|
395 |
+
"jupyter": {
|
396 |
+
"outputs_hidden": false,
|
397 |
+
"source_hidden": false
|
398 |
+
},
|
399 |
+
"nteract": {
|
400 |
+
"transient": {
|
401 |
+
"deleting": false
|
402 |
+
}
|
403 |
+
}
|
404 |
+
},
|
405 |
+
"outputs": [
|
406 |
+
{
|
407 |
+
"data": {
|
408 |
+
"text/plain": [
|
409 |
+
"DatasetDict({\n",
|
410 |
+
" train: Dataset({\n",
|
411 |
+
" features: ['id', 'text', 'label'],\n",
|
412 |
+
" num_rows: 1270444\n",
|
413 |
+
" })\n",
|
414 |
+
" test: Dataset({\n",
|
415 |
+
" features: ['id', 'text', 'label'],\n",
|
416 |
+
" num_rows: 272238\n",
|
417 |
+
" })\n",
|
418 |
+
" val: Dataset({\n",
|
419 |
+
" features: ['id', 'text', 'label'],\n",
|
420 |
+
" num_rows: 272238\n",
|
421 |
+
" })\n",
|
422 |
+
"})"
|
423 |
+
]
|
424 |
+
},
|
425 |
+
"execution_count": 8,
|
426 |
+
"metadata": {},
|
427 |
+
"output_type": "execute_result"
|
428 |
+
}
|
429 |
+
],
|
430 |
+
"source": [
|
431 |
+
"dataset"
|
432 |
+
]
|
433 |
+
},
|
434 |
+
{
|
435 |
+
"cell_type": "code",
|
436 |
+
"execution_count": 8,
|
437 |
+
"metadata": {
|
438 |
+
"gather": {
|
439 |
+
"logged": 1706503446498
|
440 |
+
}
|
441 |
+
},
|
442 |
+
"outputs": [],
|
443 |
+
"source": [
|
444 |
+
"SUBSAMPLING = 0.01\n",
|
445 |
+
"\n",
|
446 |
+
"if SUBSAMPLING < 1:\n",
|
447 |
+
" _ = DatasetDict()\n",
|
448 |
+
" for each in dataset.keys():\n",
|
449 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
450 |
+
"\n",
|
451 |
+
" dataset = _"
|
452 |
+
]
|
453 |
+
},
|
454 |
+
{
|
455 |
+
"cell_type": "markdown",
|
456 |
+
"metadata": {},
|
457 |
+
"source": [
|
458 |
+
"## Tokenisation and encoding"
|
459 |
+
]
|
460 |
+
},
|
461 |
+
{
|
462 |
+
"cell_type": "code",
|
463 |
+
"execution_count": 10,
|
464 |
+
"metadata": {
|
465 |
+
"gather": {
|
466 |
+
"logged": 1706503446633
|
467 |
+
}
|
468 |
+
},
|
469 |
+
"outputs": [],
|
470 |
+
"source": [
|
471 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
472 |
+
" return ds_enc"
|
473 |
+
]
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"cell_type": "markdown",
|
477 |
+
"metadata": {},
|
478 |
+
"source": [
|
479 |
+
"## Evaluation metrics"
|
480 |
+
]
|
481 |
+
},
|
482 |
+
{
|
483 |
+
"cell_type": "code",
|
484 |
+
"execution_count": 9,
|
485 |
+
"metadata": {
|
486 |
+
"gather": {
|
487 |
+
"logged": 1706503446863
|
488 |
+
}
|
489 |
+
},
|
490 |
+
"outputs": [],
|
491 |
+
"source": [
|
492 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
493 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
494 |
+
"f1 = evaluate.load(\"f1\")"
|
495 |
+
]
|
496 |
+
},
|
497 |
+
{
|
498 |
+
"cell_type": "code",
|
499 |
+
"execution_count": 10,
|
500 |
+
"metadata": {
|
501 |
+
"gather": {
|
502 |
+
"logged": 1706503447004
|
503 |
+
}
|
504 |
+
},
|
505 |
+
"outputs": [],
|
506 |
+
"source": [
|
507 |
+
"def compute_metrics(eval_pred):\n",
|
508 |
+
" predictions, labels = eval_pred\n",
|
509 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
510 |
+
" return {\n",
|
511 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
512 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
513 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
514 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
515 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
516 |
+
" 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
517 |
+
" }"
|
518 |
+
]
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"cell_type": "markdown",
|
522 |
+
"metadata": {},
|
523 |
+
"source": [
|
524 |
+
"## Training"
|
525 |
+
]
|
526 |
+
},
|
527 |
+
{
|
528 |
+
"cell_type": "markdown",
|
529 |
+
"metadata": {},
|
530 |
+
"source": [
|
+"We specify a label map manually: even though `Datasets` has a function for this, `AutoModelForSequenceClassification` requires an object with a length :("
+]
+},
+{
+"cell_type": "code",
+"execution_count": 11,
+"metadata": {
+"gather": {
+"logged": 1706503447186
+}
+},
+"outputs": [],
+"source": [
+"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 12,
+"metadata": {
+"jupyter": {
+"outputs_hidden": false,
+"source_hidden": false
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+},
+"outputs": [],
+"source": [
+"def train_from_model(model_ckpt: str, push: bool = False):\n",
+"    print(f\"Initialising training based on {model_ckpt}...\")\n",
+"\n",
+"    print(\"Tokenising...\")\n",
+"    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
+"\n",
+"    cols = dataset[\"train\"].column_names\n",
+"    cols.remove(\"label\")\n",
+"    ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
+"\n",
+"    print(\"Loading model...\")\n",
+"    try:\n",
+"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,\n",
+"            num_labels=len(dataset[\"test\"].features[\"label\"].names),\n",
+"            id2label=label_map,\n",
+"            label2id={v: k for k, v in label_map.items()})\n",
+"    except OSError:\n",
+"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,\n",
+"            num_labels=len(dataset[\"test\"].features[\"label\"].names),\n",
+"            id2label=label_map,\n",
+"            label2id={v: k for k, v in label_map.items()},\n",
+"            from_tf=True)\n",
+"\n",
+"\n",
+"    args = TrainingArguments(\n",
+"        output_dir=\"vaers\",\n",
+"        evaluation_strategy=\"epoch\",\n",
+"        save_strategy=\"epoch\",\n",
+"        learning_rate=2e-5,\n",
+"        per_device_train_batch_size=BATCH_SIZE,\n",
+"        per_device_eval_batch_size=BATCH_SIZE,\n",
+"        num_train_epochs=EPOCHS,\n",
+"        weight_decay=.01,\n",
+"        logging_steps=1,\n",
+"        load_best_model_at_end=True,\n",
+"        run_name=f\"daedra-training\",\n",
+"        report_to=[\"wandb\"])\n",
+"\n",
+"    trainer = Trainer(\n",
+"        model=model,\n",
+"        args=args,\n",
+"        train_dataset=ds_enc[\"train\"],\n",
+"        eval_dataset=ds_enc[\"test\"],\n",
+"        tokenizer=tokenizer,\n",
+"        compute_metrics=compute_metrics)\n",
+"\n",
+"    if SUBSAMPLING != 1.0:\n",
+"        wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
+"    else:\n",
+"        wandb_tag: List[str] = [f\"full_sample\"]\n",
+"\n",
+"    wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
+"    wandb_tag.append(f\"base:{model_ckpt}\")\n",
+"\n",
+"    wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
+"\n",
+"    print(\"Starting training...\")\n",
+"\n",
+"    trainer.train()\n",
+"\n",
+"    print(\"Training finished.\")\n",
+"\n",
+"    if push:\n",
+"        variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+"        tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+"        tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+"        sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+"\n",
+"        model.push_to_hub(\"chrisvoncsefalvay/daedra\",\n",
+"            variant=variant,\n",
+"            commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 13,
+"metadata": {
+"gather": {
+"logged": 1706503552083
+}
+},
+"outputs": [],
+"source": [
+"\n",
+"base_models = [\n",
+"    \"bert-base-uncased\",\n",
+"    \"distilbert-base-uncased\",\n",
+"]"
+]
+},
+{
+"cell_type": "code",
+"execution_count": null,
+"metadata": {},
+"outputs": [],
+"source": [
+"for md in base_models:\n",
+"    train_from_model(md)"
+]
+}
+],
+"metadata": {
+"datalore": {
+"base_environment": "default",
+"computation_mode": "JUPYTER",
+"package_manager": "pip",
+"packages": [
+{
+"name": "datasets",
+"source": "PIP",
+"version": "2.16.1"
+},
+{
+"name": "torch",
+"source": "PIP",
+"version": "2.1.2"
+},
+{
+"name": "accelerate",
+"source": "PIP",
+"version": "0.26.1"
+}
+],
+"report_row_ids": [
+"un8W7ez7ZwoGb5Co6nydEV",
+"40nN9Hvgi1clHNV5RAemI5",
+"TgRD90H5NSPpKS41OeXI1w",
+"ZOm5BfUs3h1EGLaUkBGeEB",
+"kOP0CZWNSk6vqE3wkPp7Vc",
+"W4PWcOu2O2pRaZyoE2W80h",
+"RolbOnQLIftk0vy9mIcz5M",
+"8OPhUgbaNJmOdiq5D3a6vK",
+"5Qrt3jSvSrpK6Ne1hS6shL",
+"hTq7nFUrovN5Ao4u6dIYWZ",
+"I8WNZLpJ1DVP2wiCW7YBIB",
+"SawhU3I9BewSE1XBPstpNJ",
+"80EtLEl2FIE4FqbWnUD3nT"
+],
+"version": 3
+},
+"kernel_info": {
+"name": "python38-azureml-pt-tf"
+},
+"kernelspec": {
+"display_name": "azureml_py38_PT_TF",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"file_extension": ".py",
+"mimetype": "text/x-python",
+"name": "python",
+"nbconvert_exporter": "python",
+"pygments_lexer": "ipython3",
+"version": "3.8.5"
+},
+"microsoft": {
+"host": {
+"AzureML": {
+"notebookHasBeenCompleted": true
+}
+},
+"ms_spell_check": {
+"ms_spell_check_language": "en"
+}
+},
+"nteract": {
+"version": "nteract-front-end@1.0.0"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 4
+}
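
Note that the `encode_ds` helper in the notebook above is never actually called: `train_from_model` tokenises inline instead. A minimal sketch of how it could be used standalone, assuming the notebook's `dataset` global is already loaded (illustrative only, not part of the committed code):

    ds_enc = encode_ds(dataset)          # tokenises every split of the DatasetDict
    print(ds_enc["train"].column_names)  # now includes input_ids and attention_mask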
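
Relatedly, `compute_metrics` reports macro- and micro-averaged precision and recall but only a micro-averaged F1. If a macro-averaged F1 were wanted for symmetry, a hypothetical extra entry in the returned dictionary would be:

    'f1_macroaverage': f1.compute(predictions=predictions, references=labels, average='macro')["f1"]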
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-3-40-27Z.ipynb
ADDED
@@ -0,0 +1,1001 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"metadata": {},
+"source": [
+"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
+"\n",
+"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
+]
+},
+{
+"cell_type": "code",
+"execution_count": 1,
+"metadata": {
+"gather": {
+"logged": 1706475754655
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+},
+"tags": []
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
+"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
+"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
+"Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
+"Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
+"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
+"Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
+"Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
+"Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
+"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
+"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
+"Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
+"Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
+"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
+"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
+"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
+"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
+"Note: you may need to restart the kernel to use updated packages.\n"
+]
+}
+],
+"source": [
+"%pip install accelerate -U"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 2,
+"metadata": {
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+},
+"outputs": [
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
+"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
+"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
+"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
+"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
+"Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
+"Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
+"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
+"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
+"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
+"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
+"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
+"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
+"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
+"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
+"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
+"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
+"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
+"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
+"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
+"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
+"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
+"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
+"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
+"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
+"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
+"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
+"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
+"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
+"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
+"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
+"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
+"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
+"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
+"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
+"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
+"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
+"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
+"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
+"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
+"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
+"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
+"Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
+"Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
+"Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
+"Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
+"Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
+"Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
+"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
+"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
+"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
+"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
+"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
+"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
+"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
+"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
+"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
+"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
+"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
+"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
+"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
+"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
+"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
+"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
+"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
+"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
+"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
+"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
+"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
+"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
+"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
+"Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
+"Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
+"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
+"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
+"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
+"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
+"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
+"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
+"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
+"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
+"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
+"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
+"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
+"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
+"Note: you may need to restart the kernel to use updated packages.\n"
+]
+}
+],
+"source": [
+"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
+]
+},
+{
+"cell_type": "code",
+"execution_count": 3,
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"node_id": "caZjjFP0OyQNMVgZDiwswE",
+"report_properties": {
+"rowId": "un8W7ez7ZwoGb5Co6nydEV"
+},
+"type": "CODE"
+},
+"gather": {
+"logged": 1706486372154
+},
+"tags": []
+},
+"outputs": [
+{
+"name": "stderr",
+"output_type": "stream",
+"text": [
+"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n",
+"  from .autonotebook import tqdm as notebook_tqdm\n",
+"2024-01-28 23:59:27.034680: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\n",
+"To enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n",
+"2024-01-28 23:59:27.996419: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n",
+"2024-01-28 23:59:27.999143: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n",
+"2024-01-28 23:59:27.999161: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n",
+"[codecarbon INFO @ 23:59:30] [setup] RAM Tracking...\n",
+"[codecarbon INFO @ 23:59:30] [setup] GPU Tracking...\n",
+"[codecarbon INFO @ 23:59:30] Tracking Nvidia GPU via pynvml\n",
+"[codecarbon INFO @ 23:59:30] [setup] CPU Tracking...\n",
+"[codecarbon WARNING @ 23:59:30] No CPU tracking mode found. Falling back on CPU constant mode.\n",
+"[codecarbon WARNING @ 23:59:31] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n",
+"[codecarbon INFO @ 23:59:31] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
+"[codecarbon INFO @ 23:59:31] >>> Tracker's metadata:\n",
+"[codecarbon INFO @ 23:59:31] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n",
+"[codecarbon INFO @ 23:59:31] Python version: 3.8.5\n",
+"[codecarbon INFO @ 23:59:31] CodeCarbon version: 2.3.3\n",
+"[codecarbon INFO @ 23:59:31] Available RAM : 440.883 GB\n",
+"[codecarbon INFO @ 23:59:31] CPU count: 24\n",
+"[codecarbon INFO @ 23:59:31] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
+"[codecarbon INFO @ 23:59:31] GPU count: 4\n",
|
209 |
+
"[codecarbon INFO @ 23:59:31] GPU model: 4 x Tesla V100-PCIE-16GB\n",
|
210 |
+
"[codecarbon WARNING @ 23:59:32] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
|
211 |
+
]
|
212 |
+
}
|
213 |
+
],
|
214 |
+
"source": [
|
215 |
+
"import pandas as pd\n",
|
216 |
+
"import numpy as np\n",
|
217 |
+
"import torch\n",
|
218 |
+
"import os\n",
|
219 |
+
"from typing import List, Union\n",
|
220 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
|
221 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
222 |
+
"import shap\n",
|
223 |
+
"import wandb\n",
|
224 |
+
"import evaluate\n",
|
225 |
+
"from codecarbon import EmissionsTracker\n",
|
226 |
+
"import logging\n",
|
227 |
+
"\n",
|
228 |
+
"logging.getLogger('codecarbon').propagate = False\n",
|
229 |
+
"\n",
|
230 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
231 |
+
"tracker = EmissionsTracker()\n",
|
232 |
+
"\n",
|
233 |
+
"%load_ext watermark"
|
234 |
+
]
|
235 |
+
},
|
236 |
+
{
|
237 |
+
"cell_type": "code",
|
238 |
+
"execution_count": 4,
|
239 |
+
"metadata": {
|
240 |
+
"collapsed": false,
|
241 |
+
"gather": {
|
242 |
+
"logged": 1706486372304
|
243 |
+
},
|
244 |
+
"jupyter": {
|
245 |
+
"outputs_hidden": false
|
246 |
+
}
|
247 |
+
},
|
248 |
+
"outputs": [],
|
249 |
+
"source": [
|
250 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
251 |
+
"\n",
|
252 |
+
"SEED: int = 42\n",
|
253 |
+
"\n",
|
254 |
+
"BATCH_SIZE: int = 32\n",
|
255 |
+
"EPOCHS: int = 3\n",
|
256 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
257 |
+
"\n",
|
258 |
+
"# WandB configuration\n",
|
259 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
260 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
261 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
262 |
+
]
|
263 |
+
},
|
264 |
+
{
|
265 |
+
"cell_type": "code",
|
266 |
+
"execution_count": 5,
|
267 |
+
"metadata": {
|
268 |
+
"collapsed": false,
|
269 |
+
"jupyter": {
|
270 |
+
"outputs_hidden": false
|
271 |
+
}
|
272 |
+
},
|
273 |
+
"outputs": [
|
274 |
+
{
|
275 |
+
"name": "stdout",
|
276 |
+
"output_type": "stream",
|
277 |
+
"text": [
|
278 |
+
"re : 2.2.1\n",
|
279 |
+
"evaluate: 0.4.1\n",
|
280 |
+
"pandas : 2.0.2\n",
|
281 |
+
"wandb : 0.16.2\n",
|
282 |
+
"numpy : 1.23.5\n",
|
283 |
+
"torch : 1.12.0\n",
|
284 |
+
"logging : 0.5.1.2\n",
|
285 |
+
"shap : 0.44.1\n",
|
286 |
+
"\n"
|
287 |
+
]
|
288 |
+
}
|
289 |
+
],
|
290 |
+
"source": [
|
291 |
+
"%watermark --iversion"
|
292 |
+
]
|
293 |
+
},
|
294 |
+
{
|
295 |
+
"cell_type": "code",
|
296 |
+
"execution_count": 6,
|
297 |
+
"metadata": {
|
298 |
+
"datalore": {
|
299 |
+
"hide_input_from_viewers": true,
|
300 |
+
"hide_output_from_viewers": true,
|
301 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
302 |
+
"type": "CODE"
|
303 |
+
}
|
304 |
+
},
|
305 |
+
"outputs": [
|
306 |
+
{
|
307 |
+
"name": "stdout",
|
308 |
+
"output_type": "stream",
|
309 |
+
"text": [
|
310 |
+
"Sun Jan 28 23:59:32 2024 \r\n",
|
311 |
+
"+---------------------------------------------------------------------------------------+\r\n",
|
312 |
+
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n",
|
313 |
+
"|-----------------------------------------+----------------------+----------------------+\r\n",
|
314 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n",
|
315 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n",
|
316 |
+
"| | | MIG M. |\r\n",
|
317 |
+
"|=========================================+======================+======================|\r\n",
|
318 |
+
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n",
|
319 |
+
"| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
|
320 |
+
"| | | N/A |\r\n",
|
321 |
+
"+-----------------------------------------+----------------------+----------------------+\r\n",
|
322 |
+
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n",
|
323 |
+
"| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
|
324 |
+
"| | | N/A |\r\n",
|
325 |
+
"+-----------------------------------------+----------------------+----------------------+\r\n",
|
326 |
+
"| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n",
|
327 |
+
"| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
|
328 |
+
"| | | N/A |\r\n",
|
329 |
+
"+-----------------------------------------+----------------------+----------------------+\r\n",
|
330 |
+
"| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n",
|
331 |
+
"| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n",
|
332 |
+
"| | | N/A |\r\n",
|
333 |
+
"+-----------------------------------------+----------------------+----------------------+\r\n",
|
334 |
+
" \r\n",
|
335 |
+
"+---------------------------------------------------------------------------------------+\r\n",
|
336 |
+
"| Processes: |\r\n",
|
337 |
+
"| GPU GI CI PID Type Process name GPU Memory |\r\n",
|
338 |
+
"| ID ID Usage |\r\n",
|
339 |
+
"|=======================================================================================|\r\n",
|
340 |
+
"| No running processes found |\r\n",
|
341 |
+
"+---------------------------------------------------------------------------------------+\r\n"
|
342 |
+
]
|
343 |
+
}
|
344 |
+
],
|
345 |
+
"source": [
|
346 |
+
"!nvidia-smi"
|
347 |
+
]
|
348 |
+
},
|
349 |
+
{
|
350 |
+
"cell_type": "markdown",
|
351 |
+
"metadata": {
|
352 |
+
"datalore": {
|
353 |
+
"hide_input_from_viewers": false,
|
354 |
+
"hide_output_from_viewers": false,
|
355 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
356 |
+
"report_properties": {
|
357 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
358 |
+
},
|
359 |
+
"type": "MD"
|
360 |
+
}
|
361 |
+
},
|
362 |
+
"source": [
|
363 |
+
"## Loading the data set"
|
364 |
+
]
|
365 |
+
},
|
366 |
+
{
|
367 |
+
"cell_type": "code",
|
368 |
+
"execution_count": 7,
|
369 |
+
"metadata": {
|
370 |
+
"collapsed": false,
|
371 |
+
"gather": {
|
372 |
+
"logged": 1706486373931
|
373 |
+
},
|
374 |
+
"jupyter": {
|
375 |
+
"outputs_hidden": false
|
376 |
+
}
|
377 |
+
},
|
378 |
+
"outputs": [],
|
379 |
+
"source": [
|
380 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
381 |
+
]
|
382 |
+
},
|
383 |
+
{
|
384 |
+
"cell_type": "code",
|
385 |
+
"execution_count": 8,
|
386 |
+
"metadata": {
|
387 |
+
"collapsed": false,
|
388 |
+
"gather": {
|
389 |
+
"logged": 1706486374218
|
390 |
+
},
|
391 |
+
"jupyter": {
|
392 |
+
"outputs_hidden": false,
|
393 |
+
"source_hidden": false
|
394 |
+
},
|
395 |
+
"nteract": {
|
396 |
+
"transient": {
|
397 |
+
"deleting": false
|
398 |
+
}
|
399 |
+
}
|
400 |
+
},
|
401 |
+
"outputs": [
|
402 |
+
{
|
403 |
+
"data": {
|
404 |
+
"text/plain": [
|
405 |
+
"DatasetDict({\n",
|
406 |
+
" train: Dataset({\n",
|
407 |
+
" features: ['id', 'text', 'label'],\n",
|
408 |
+
" num_rows: 1270444\n",
|
409 |
+
" })\n",
|
410 |
+
" test: Dataset({\n",
|
411 |
+
" features: ['id', 'text', 'label'],\n",
|
412 |
+
" num_rows: 272238\n",
|
413 |
+
" })\n",
|
414 |
+
" val: Dataset({\n",
|
415 |
+
" features: ['id', 'text', 'label'],\n",
|
416 |
+
" num_rows: 272238\n",
|
417 |
+
" })\n",
|
418 |
+
"})"
|
419 |
+
]
|
420 |
+
},
|
421 |
+
"execution_count": 8,
|
422 |
+
"metadata": {},
|
423 |
+
"output_type": "execute_result"
|
424 |
+
}
|
425 |
+
],
|
426 |
+
"source": [
|
427 |
+
"dataset"
|
428 |
+
]
|
429 |
+
},
|
430 |
+
{
|
431 |
+
"cell_type": "code",
|
432 |
+
"execution_count": 9,
|
433 |
+
"metadata": {
|
434 |
+
"gather": {
|
435 |
+
"logged": 1706486374480
|
436 |
+
}
|
437 |
+
},
|
438 |
+
"outputs": [],
|
439 |
+
"source": [
|
440 |
+
"SUBSAMPLING = 0.5\n",
|
441 |
+
"\n",
|
442 |
+
"if SUBSAMPLING < 1:\n",
|
443 |
+
" _ = DatasetDict()\n",
|
444 |
+
" for each in dataset.keys():\n",
|
445 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
446 |
+
"\n",
|
447 |
+
" dataset = _"
|
448 |
+
]
|
449 |
+
},
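The subsampling cell above shuffles each split deterministically and keeps the first `SUBSAMPLING` fraction of its rows. A minimal equivalent sketch as a reusable helper (the `subsample` name is ours, not something the notebook defines):

    from datasets import DatasetDict

    def subsample(ds: DatasetDict, fraction: float, seed: int = 42) -> DatasetDict:
        # Shuffle each split with a fixed seed, then keep its leading fraction of rows.
        return DatasetDict({
            name: split.shuffle(seed=seed).select(range(int(len(split) * fraction)))
            for name, split in ds.items()
        })

    dataset = subsample(dataset, SUBSAMPLING, SEED)

A comprehension over `ds.items()` avoids mutating the throwaway `_` dictionary the loop uses.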
|
450 |
+
{
|
451 |
+
"cell_type": "markdown",
|
452 |
+
"metadata": {},
|
453 |
+
"source": [
|
454 |
+
"## Tokenisation and encoding"
|
455 |
+
]
|
456 |
+
},
|
457 |
+
{
|
458 |
+
"cell_type": "code",
|
459 |
+
"execution_count": 10,
|
460 |
+
"metadata": {
|
461 |
+
"gather": {
|
462 |
+
"logged": 1706486375030
|
463 |
+
}
|
464 |
+
},
|
465 |
+
"outputs": [],
|
466 |
+
"source": [
|
467 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
468 |
+
"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)\n",
"    ds_enc = ds.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True)\n",
"    return ds_enc"
|
469 |
+
]
|
470 |
+
},
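Before mapping the tokeniser over more than a million rows, it can be worth a one-off sanity check of its output. A sketch (the report sentence is invented; `distilbert-base-uncased` returns only `input_ids` and `attention_mask`, with no `token_type_ids`):

    from transformers import AutoTokenizer

    tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
    enc = tok("Patient reported fever and chills after vaccination.", truncation=True)
    print(list(enc.keys()))                                 # ['input_ids', 'attention_mask']
    print(tok.convert_ids_to_tokens(enc["input_ids"])[:4])  # ['[CLS]', 'patient', 'reported', 'fever']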
|
471 |
+
{
|
472 |
+
"cell_type": "markdown",
|
473 |
+
"metadata": {},
|
474 |
+
"source": [
|
475 |
+
"## Evaluation metrics"
|
476 |
+
]
|
477 |
+
},
|
478 |
+
{
|
479 |
+
"cell_type": "code",
|
480 |
+
"execution_count": 11,
|
481 |
+
"metadata": {
|
482 |
+
"gather": {
|
483 |
+
"logged": 1706486375197
|
484 |
+
}
|
485 |
+
},
|
486 |
+
"outputs": [],
|
487 |
+
"source": [
|
488 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
489 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
490 |
+
"f1 = evaluate.load(\"f1\")"
|
491 |
+
]
|
492 |
+
},
|
493 |
+
{
|
494 |
+
"cell_type": "code",
|
495 |
+
"execution_count": 12,
|
496 |
+
"metadata": {
|
497 |
+
"gather": {
|
498 |
+
"logged": 1706486375361
|
499 |
+
}
|
500 |
+
},
|
501 |
+
"outputs": [],
|
502 |
+
"source": [
|
503 |
+
"def compute_metrics(eval_pred):\n",
|
504 |
+
" predictions, labels = eval_pred\n",
|
505 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
506 |
+
" return {\n",
|
507 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
508 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
509 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
510 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
511 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
512 |
+
"        'f1_macroaverage': f1.compute(predictions=predictions, references=labels, average='macro')[\"f1\"],\n",
"        'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
513 |
+
" }"
|
514 |
+
]
|
515 |
+
},
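`Trainer` invokes `compute_metrics` with an `EvalPrediction` containing the raw logits and the reference labels; a plain `(logits, labels)` tuple unpacks the same way, so the function can be smoke-tested before any training run. A toy three-class example with invented values:

    import numpy as np

    logits = np.array([[2.0, 0.1, 0.3],
                       [0.2, 0.1, 3.0]])   # argmax -> [0, 2]
    labels = np.array([0, 2])
    print(compute_metrics((logits, labels)))
    # Both predictions match the labels, so every metric comes out as 1.0.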
|
516 |
+
{
|
517 |
+
"cell_type": "markdown",
|
518 |
+
"metadata": {},
|
519 |
+
"source": [
|
520 |
+
"## Training"
|
521 |
+
]
|
522 |
+
},
|
523 |
+
{
|
524 |
+
"cell_type": "markdown",
|
525 |
+
"metadata": {},
|
526 |
+
"source": [
|
527 |
+
"We specify the label map manually: although `Datasets` has a utility for this, `AutoModelForSequenceClassification` requires an object with a length (such as a dict mapping ids to label names), so we build it ourselves."
|
528 |
+
]
|
529 |
+
},
|
530 |
+
{
|
531 |
+
"cell_type": "code",
|
532 |
+
"execution_count": 13,
|
533 |
+
"metadata": {
|
534 |
+
"gather": {
|
535 |
+
"logged": 1706486375569
|
536 |
+
}
|
537 |
+
},
|
538 |
+
"outputs": [],
|
539 |
+
"source": [
|
540 |
+
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
541 |
+
]
|
542 |
+
},
|
543 |
+
{
|
544 |
+
"cell_type": "code",
|
545 |
+
"execution_count": 14,
|
546 |
+
"metadata": {
|
547 |
+
"gather": {
|
548 |
+
"logged": 1706486433708
|
549 |
+
}
|
550 |
+
},
|
551 |
+
"outputs": [
|
552 |
+
{
|
553 |
+
"name": "stderr",
|
554 |
+
"output_type": "stream",
|
555 |
+
"text": [
|
556 |
+
"Map: 100%|██████████| 136119/136119 [00:56<00:00, 2412.00 examples/s]\n",
|
557 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
|
558 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
559 |
+
]
|
560 |
+
}
|
561 |
+
],
|
562 |
+
"source": [
|
563 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
564 |
+
"\n",
|
565 |
+
"cols = dataset[\"train\"].column_names\n",
|
566 |
+
"cols.remove(\"label\")\n",
|
567 |
+
"ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n",
|
568 |
+
"\n",
|
569 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
570 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
571 |
+
" id2label=label_map, \n",
|
572 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
573 |
+
"\n",
|
574 |
+
"args = TrainingArguments(\n",
|
575 |
+
" output_dir=\"vaers\",\n",
|
576 |
+
" evaluation_strategy=\"epoch\",\n",
|
577 |
+
" save_strategy=\"epoch\",\n",
|
578 |
+
" learning_rate=2e-5,\n",
|
579 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
580 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
581 |
+
" num_train_epochs=EPOCHS,\n",
|
582 |
+
" weight_decay=.01,\n",
|
583 |
+
" logging_steps=1,\n",
|
584 |
+
" load_best_model_at_end=True,\n",
|
585 |
+
"    run_name=\"daedra-training\",\n",
|
586 |
+
" report_to=[\"wandb\"])\n",
|
587 |
+
"\n",
|
588 |
+
"trainer = Trainer(\n",
|
589 |
+
" model=model,\n",
|
590 |
+
" args=args,\n",
|
591 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
592 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
593 |
+
" tokenizer=tokenizer,\n",
|
594 |
+
" compute_metrics=compute_metrics)"
|
595 |
+
]
|
596 |
+
},
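Once training has run, the fine-tuned model can be exercised through the `pipeline` API already imported at the top of the notebook. A sketch (the report text is invented, and `device=0` assumes at least one GPU is visible):

    clf = pipeline("text-classification", model=model, tokenizer=tokenizer, device=0)
    print(clf("Patient developed a mild rash at the injection site."))
    # e.g. [{'label': ..., 'score': ...}] -- the label string comes from the id2label map above.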
|
597 |
+
{
|
598 |
+
"cell_type": "code",
|
599 |
+
"execution_count": 15,
|
600 |
+
"metadata": {
|
601 |
+
"gather": {
|
602 |
+
"logged": 1706486444806
|
603 |
+
}
|
604 |
+
},
|
605 |
+
"outputs": [
|
606 |
+
{
|
607 |
+
"name": "stderr",
|
608 |
+
"output_type": "stream",
|
609 |
+
"text": [
|
610 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
|
611 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
612 |
+
]
|
613 |
+
},
|
614 |
+
{
|
615 |
+
"data": {
|
616 |
+
"text/html": [
|
617 |
+
"Tracking run with wandb version 0.16.2"
|
618 |
+
],
|
619 |
+
"text/plain": [
|
620 |
+
"<IPython.core.display.HTML object>"
|
621 |
+
]
|
622 |
+
},
|
623 |
+
"metadata": {},
|
624 |
+
"output_type": "display_data"
|
625 |
+
},
|
626 |
+
{
|
627 |
+
"data": {
|
628 |
+
"text/html": [
|
629 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_000035-xm7aguww</code>"
|
630 |
+
],
|
631 |
+
"text/plain": [
|
632 |
+
"<IPython.core.display.HTML object>"
|
633 |
+
]
|
634 |
+
},
|
635 |
+
"metadata": {},
|
636 |
+
"output_type": "display_data"
|
637 |
+
},
|
638 |
+
{
|
639 |
+
"data": {
|
640 |
+
"text/html": [
|
641 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
642 |
+
],
|
643 |
+
"text/plain": [
|
644 |
+
"<IPython.core.display.HTML object>"
|
645 |
+
]
|
646 |
+
},
|
647 |
+
"metadata": {},
|
648 |
+
"output_type": "display_data"
|
649 |
+
},
|
650 |
+
{
|
651 |
+
"data": {
|
652 |
+
"text/html": [
|
653 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
654 |
+
],
|
655 |
+
"text/plain": [
|
656 |
+
"<IPython.core.display.HTML object>"
|
657 |
+
]
|
658 |
+
},
|
659 |
+
"metadata": {},
|
660 |
+
"output_type": "display_data"
|
661 |
+
},
|
662 |
+
{
|
663 |
+
"data": {
|
664 |
+
"text/html": [
|
665 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww</a>"
|
666 |
+
],
|
667 |
+
"text/plain": [
|
668 |
+
"<IPython.core.display.HTML object>"
|
669 |
+
]
|
670 |
+
},
|
671 |
+
"metadata": {},
|
672 |
+
"output_type": "display_data"
|
673 |
+
},
|
674 |
+
{
|
675 |
+
"data": {
|
676 |
+
"text/html": [
|
677 |
+
"Finishing last run (ID:xm7aguww) before initializing another..."
|
678 |
+
],
|
679 |
+
"text/plain": [
|
680 |
+
"<IPython.core.display.HTML object>"
|
681 |
+
]
|
682 |
+
},
|
683 |
+
"metadata": {},
|
684 |
+
"output_type": "display_data"
|
685 |
+
},
|
686 |
+
{
|
687 |
+
"data": {
|
688 |
+
"text/html": [
|
689 |
+
" View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/xm7aguww</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
690 |
+
],
|
691 |
+
"text/plain": [
|
692 |
+
"<IPython.core.display.HTML object>"
|
693 |
+
]
|
694 |
+
},
|
695 |
+
"metadata": {},
|
696 |
+
"output_type": "display_data"
|
697 |
+
},
|
698 |
+
{
|
699 |
+
"data": {
|
700 |
+
"text/html": [
|
701 |
+
"Find logs at: <code>./wandb/run-20240129_000035-xm7aguww/logs</code>"
|
702 |
+
],
|
703 |
+
"text/plain": [
|
704 |
+
"<IPython.core.display.HTML object>"
|
705 |
+
]
|
706 |
+
},
|
707 |
+
"metadata": {},
|
708 |
+
"output_type": "display_data"
|
709 |
+
},
|
710 |
+
{
|
711 |
+
"data": {
|
712 |
+
"text/html": [
|
713 |
+
"Successfully finished last run (ID:xm7aguww). Initializing new run:<br/>"
|
714 |
+
],
|
715 |
+
"text/plain": [
|
716 |
+
"<IPython.core.display.HTML object>"
|
717 |
+
]
|
718 |
+
},
|
719 |
+
"metadata": {},
|
720 |
+
"output_type": "display_data"
|
721 |
+
},
|
722 |
+
{
|
723 |
+
"data": {
|
724 |
+
"text/html": [
|
725 |
+
"Tracking run with wandb version 0.16.2"
|
726 |
+
],
|
727 |
+
"text/plain": [
|
728 |
+
"<IPython.core.display.HTML object>"
|
729 |
+
]
|
730 |
+
},
|
731 |
+
"metadata": {},
|
732 |
+
"output_type": "display_data"
|
733 |
+
},
|
734 |
+
{
|
735 |
+
"data": {
|
736 |
+
"text/html": [
|
737 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_000037-qfvjuxwm</code>"
|
738 |
+
],
|
739 |
+
"text/plain": [
|
740 |
+
"<IPython.core.display.HTML object>"
|
741 |
+
]
|
742 |
+
},
|
743 |
+
"metadata": {},
|
744 |
+
"output_type": "display_data"
|
745 |
+
},
|
746 |
+
{
|
747 |
+
"data": {
|
748 |
+
"text/html": [
|
749 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
750 |
+
],
|
751 |
+
"text/plain": [
|
752 |
+
"<IPython.core.display.HTML object>"
|
753 |
+
]
|
754 |
+
},
|
755 |
+
"metadata": {},
|
756 |
+
"output_type": "display_data"
|
757 |
+
},
|
758 |
+
{
|
759 |
+
"data": {
|
760 |
+
"text/html": [
|
761 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
762 |
+
],
|
763 |
+
"text/plain": [
|
764 |
+
"<IPython.core.display.HTML object>"
|
765 |
+
]
|
766 |
+
},
|
767 |
+
"metadata": {},
|
768 |
+
"output_type": "display_data"
|
769 |
+
},
|
770 |
+
{
|
771 |
+
"data": {
|
772 |
+
"text/html": [
|
773 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm</a>"
|
774 |
+
],
|
775 |
+
"text/plain": [
|
776 |
+
"<IPython.core.display.HTML object>"
|
777 |
+
]
|
778 |
+
},
|
779 |
+
"metadata": {},
|
780 |
+
"output_type": "display_data"
|
781 |
+
},
|
782 |
+
{
|
783 |
+
"data": {
|
784 |
+
"text/html": [
|
785 |
+
"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/qfvjuxwm?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
|
786 |
+
],
|
787 |
+
"text/plain": [
|
788 |
+
"<wandb.sdk.wandb_run.Run at 0x7f6e1c1e64c0>"
|
789 |
+
]
|
790 |
+
},
|
791 |
+
"execution_count": 15,
|
792 |
+
"metadata": {},
|
793 |
+
"output_type": "execute_result"
|
794 |
+
}
|
795 |
+
],
|
796 |
+
"source": [
|
797 |
+
"if SUBSAMPLING != 1.0:\n",
|
798 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
799 |
+
"else:\n",
|
800 |
+
"    wandb_tag: List[str] = [\"full_sample\"]\n",
|
801 |
+
"\n",
|
802 |
+
"wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
803 |
+
"wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
804 |
+
" \n",
|
805 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
|
806 |
+
]
|
807 |
+
},
|
808 |
+
{
|
809 |
+
"cell_type": "code",
|
810 |
+
"execution_count": 16,
|
811 |
+
"metadata": {
|
812 |
+
"gather": {
|
813 |
+
"logged": 1706486541798
|
814 |
+
}
|
815 |
+
},
|
816 |
+
"outputs": [
|
817 |
+
{
|
818 |
+
"name": "stderr",
|
819 |
+
"output_type": "stream",
|
820 |
+
"text": [
|
821 |
+
"Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
|
822 |
+
]
|
823 |
+
},
|
824 |
+
{
|
825 |
+
"data": {
|
826 |
+
"text/html": [
|
827 |
+
"\n",
|
828 |
+
" <div>\n",
|
829 |
+
" \n",
|
830 |
+
" <progress value='138' max='14889' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
831 |
+
" [ 138/14889 01:27 < 2:38:40, 1.55 it/s, Epoch 0.03/3]\n",
|
832 |
+
" </div>\n",
|
833 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
834 |
+
" <thead>\n",
|
835 |
+
" <tr style=\"text-align: left;\">\n",
|
836 |
+
" <th>Epoch</th>\n",
|
837 |
+
" <th>Training Loss</th>\n",
|
838 |
+
" <th>Validation Loss</th>\n",
|
839 |
+
" </tr>\n",
|
840 |
+
" </thead>\n",
|
841 |
+
" <tbody>\n",
|
842 |
+
" </tbody>\n",
|
843 |
+
"</table><p>"
|
844 |
+
],
|
845 |
+
"text/plain": [
|
846 |
+
"<IPython.core.display.HTML object>"
|
847 |
+
]
|
848 |
+
},
|
849 |
+
"metadata": {},
|
850 |
+
"output_type": "display_data"
|
851 |
+
},
|
852 |
+
{
|
853 |
+
"name": "stderr",
|
854 |
+
"output_type": "stream",
|
855 |
+
"text": [
|
856 |
+
"[codecarbon INFO @ 00:00:59] Energy consumed for RAM : 0.000690 kWh. RAM Power : 165.33123922348022 W\n",
|
857 |
+
"[codecarbon INFO @ 00:00:59] Energy consumed for all GPUs : 0.001468 kWh. Total GPU Power : 351.7461416352095 W\n",
|
858 |
+
"[codecarbon INFO @ 00:00:59] Energy consumed for all CPUs : 0.000178 kWh. Total CPU Power : 42.5 W\n",
|
859 |
+
"[codecarbon INFO @ 00:00:59] 0.002336 kWh of electricity used since the beginning.\n",
|
860 |
+
"[codecarbon INFO @ 00:01:14] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n",
|
861 |
+
"[codecarbon INFO @ 00:01:14] Energy consumed for all GPUs : 0.004025 kWh. Total GPU Power : 614.3289592286081 W\n",
|
862 |
+
"[codecarbon INFO @ 00:01:14] Energy consumed for all CPUs : 0.000355 kWh. Total CPU Power : 42.5 W\n",
|
863 |
+
"[codecarbon INFO @ 00:01:14] 0.005757 kWh of electricity used since the beginning.\n",
|
864 |
+
"[codecarbon INFO @ 00:01:29] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n",
|
865 |
+
"[codecarbon INFO @ 00:01:29] Energy consumed for all GPUs : 0.006586 kWh. Total GPU Power : 615.1209943099732 W\n",
|
866 |
+
"[codecarbon INFO @ 00:01:29] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n",
|
867 |
+
"[codecarbon INFO @ 00:01:29] 0.009184 kWh of electricity used since the beginning.\n",
|
868 |
+
"[codecarbon INFO @ 00:01:44] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n",
|
869 |
+
"[codecarbon INFO @ 00:01:44] Energy consumed for all GPUs : 0.009201 kWh. Total GPU Power : 628.2177623002755 W\n",
|
870 |
+
"[codecarbon INFO @ 00:01:44] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n",
|
871 |
+
"[codecarbon INFO @ 00:01:44] 0.012664 kWh of electricity used since the beginning.\n",
|
872 |
+
"[codecarbon INFO @ 00:01:59] Energy consumed for RAM : 0.003442 kWh. RAM Power : 165.33123922348022 W\n",
|
873 |
+
"[codecarbon INFO @ 00:01:59] Energy consumed for all GPUs : 0.011831 kWh. Total GPU Power : 631.8056507544826 W\n",
|
874 |
+
"[codecarbon INFO @ 00:01:59] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n",
|
875 |
+
"[codecarbon INFO @ 00:01:59] 0.016159 kWh of electricity used since the beginning.\n",
|
876 |
+
"[codecarbon INFO @ 00:02:14] Energy consumed for RAM : 0.004130 kWh. RAM Power : 165.33123922348022 W\n",
|
877 |
+
"[codecarbon INFO @ 00:02:14] Energy consumed for all GPUs : 0.014450 kWh. Total GPU Power : 629.2086149888297 W\n",
|
878 |
+
"[codecarbon INFO @ 00:02:14] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n",
|
879 |
+
"[codecarbon INFO @ 00:02:14] 0.019643 kWh of electricity used since the beginning.\n",
|
880 |
+
"\n",
|
881 |
+
"KeyboardInterrupt\n",
|
882 |
+
"\n"
|
883 |
+
]
|
884 |
+
}
|
885 |
+
],
|
886 |
+
"source": [
|
887 |
+
"tracker.start()\n",
|
888 |
+
"trainer.train()\n",
|
889 |
+
"tracker.stop()\n"
|
890 |
+
]
|
891 |
+
},
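The cell above pairs `tracker.start()` and `tracker.stop()` by hand, so an exception mid-training (such as the `KeyboardInterrupt` captured in the output) can leave the tracker running. Recent codecarbon 2.x releases also support the context-manager protocol, which guarantees the stop call; a sketch:

    from codecarbon import EmissionsTracker

    with EmissionsTracker() as tracker:
        trainer.train()   # the tracker is stopped on exit, even if training raises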
|
892 |
+
{
|
893 |
+
"cell_type": "code",
|
894 |
+
"execution_count": null,
|
895 |
+
"metadata": {
|
896 |
+
"gather": {
|
897 |
+
"logged": 1706486541918
|
898 |
+
}
|
899 |
+
},
|
900 |
+
"outputs": [],
|
901 |
+
"source": [
|
902 |
+
"wandb.finish()"
|
903 |
+
]
|
904 |
+
},
|
905 |
+
{
|
906 |
+
"cell_type": "code",
|
907 |
+
"execution_count": null,
|
908 |
+
"metadata": {
|
909 |
+
"gather": {
|
910 |
+
"logged": 1706486541928
|
911 |
+
}
|
912 |
+
},
|
913 |
+
"outputs": [],
|
914 |
+
"source": [
|
915 |
+
"variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
|
916 |
+
"tokenizer._tokenizer.save(\"tokenizer.json\")\n",
|
917 |
+
"tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
|
918 |
+
"sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
|
919 |
+
"\n",
|
920 |
+
"model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
|
921 |
+
" variant=variant,\n",
|
922 |
+
" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
|
923 |
+
]
|
924 |
+
}
|
925 |
+
],
|
926 |
+
"metadata": {
|
927 |
+
"datalore": {
|
928 |
+
"base_environment": "default",
|
929 |
+
"computation_mode": "JUPYTER",
|
930 |
+
"package_manager": "pip",
|
931 |
+
"packages": [
|
932 |
+
{
|
933 |
+
"name": "datasets",
|
934 |
+
"source": "PIP",
|
935 |
+
"version": "2.16.1"
|
936 |
+
},
|
937 |
+
{
|
938 |
+
"name": "torch",
|
939 |
+
"source": "PIP",
|
940 |
+
"version": "2.1.2"
|
941 |
+
},
|
942 |
+
{
|
943 |
+
"name": "accelerate",
|
944 |
+
"source": "PIP",
|
945 |
+
"version": "0.26.1"
|
946 |
+
}
|
947 |
+
],
|
948 |
+
"report_row_ids": [
|
949 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
950 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
951 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
952 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
953 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
954 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
955 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
956 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
957 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
958 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
959 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
960 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
961 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
962 |
+
],
|
963 |
+
"version": 3
|
964 |
+
},
|
965 |
+
"kernel_info": {
|
966 |
+
"name": "python38-azureml-pt-tf"
|
967 |
+
},
|
968 |
+
"kernelspec": {
|
969 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow",
|
970 |
+
"language": "python",
|
971 |
+
"name": "python38-azureml-pt-tf"
|
972 |
+
},
|
973 |
+
"language_info": {
|
974 |
+
"codemirror_mode": {
|
975 |
+
"name": "ipython",
|
976 |
+
"version": 3
|
977 |
+
},
|
978 |
+
"file_extension": ".py",
|
979 |
+
"mimetype": "text/x-python",
|
980 |
+
"name": "python",
|
981 |
+
"nbconvert_exporter": "python",
|
982 |
+
"pygments_lexer": "ipython3",
|
983 |
+
"version": "3.8.5"
|
984 |
+
},
|
985 |
+
"microsoft": {
|
986 |
+
"host": {
|
987 |
+
"AzureML": {
|
988 |
+
"notebookHasBeenCompleted": true
|
989 |
+
}
|
990 |
+
},
|
991 |
+
"ms_spell_check": {
|
992 |
+
"ms_spell_check_language": "en"
|
993 |
+
}
|
994 |
+
},
|
995 |
+
"nteract": {
|
996 |
+
"version": "nteract-front-end@1.0.0"
|
997 |
+
}
|
998 |
+
},
|
999 |
+
"nbformat": 4,
|
1000 |
+
"nbformat_minor": 4
|
1001 |
+
}
|
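For completeness: once the final cell has pushed the tokeniser and model, they can be pulled back from the Hub like any other checkpoint. A sketch, assuming the push succeeded:

    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    tok = AutoTokenizer.from_pretrained("chrisvoncsefalvay/daedra")
    mdl = AutoModelForSequenceClassification.from_pretrained("chrisvoncsefalvay/daedra")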
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-29-4-40-54Z.ipynb
ADDED
@@ -0,0 +1,1073 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
8 |
+
"\n",
|
9 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event from the text of the event report. Designed to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {
|
16 |
+
"gather": {
|
17 |
+
"logged": 1706475754655
|
18 |
+
},
|
19 |
+
"nteract": {
|
20 |
+
"transient": {
|
21 |
+
"deleting": false
|
22 |
+
}
|
23 |
+
},
|
24 |
+
"tags": []
|
25 |
+
},
|
26 |
+
"outputs": [
|
27 |
+
{
|
28 |
+
"name": "stdout",
|
29 |
+
"output_type": "stream",
|
30 |
+
"text": [
|
31 |
+
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
|
32 |
+
"Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
|
33 |
+
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
|
34 |
+
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
|
35 |
+
"Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
|
36 |
+
"Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
|
37 |
+
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
|
38 |
+
"Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
|
39 |
+
"Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
|
40 |
+
"Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
|
41 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
|
42 |
+
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
|
43 |
+
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
|
44 |
+
"Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
|
45 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
|
46 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
|
47 |
+
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
|
48 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
|
49 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
50 |
+
]
|
51 |
+
}
|
52 |
+
],
|
53 |
+
"source": [
|
54 |
+
"%pip install accelerate -U"
|
55 |
+
]
|
56 |
+
},
|
57 |
+
{
|
58 |
+
"cell_type": "code",
|
59 |
+
"execution_count": 2,
|
60 |
+
"metadata": {
|
61 |
+
"nteract": {
|
62 |
+
"transient": {
|
63 |
+
"deleting": false
|
64 |
+
}
|
65 |
+
}
|
66 |
+
},
|
67 |
+
"outputs": [
|
68 |
+
{
|
69 |
+
"name": "stdout",
|
70 |
+
"output_type": "stream",
|
71 |
+
"text": [
|
72 |
+
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
|
73 |
+
"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
|
74 |
+
"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
|
75 |
+
"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
|
76 |
+
"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
|
77 |
+
"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
|
78 |
+
"Requirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\n",
|
79 |
+
"Requirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\n",
|
80 |
+
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
|
81 |
+
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
|
82 |
+
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
|
83 |
+
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
|
84 |
+
"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
|
85 |
+
"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
|
86 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
|
87 |
+
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
|
88 |
+
"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
|
89 |
+
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
|
90 |
+
"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
|
91 |
+
"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
|
92 |
+
"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
|
93 |
+
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
|
94 |
+
"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
|
95 |
+
"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
|
96 |
+
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
|
97 |
+
"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
|
98 |
+
"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
|
99 |
+
"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
|
100 |
+
"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
|
101 |
+
"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
|
102 |
+
"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
|
103 |
+
"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
|
104 |
+
"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
|
105 |
+
"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
|
106 |
+
"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
|
107 |
+
"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
|
108 |
+
"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
|
109 |
+
"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
|
110 |
+
"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
|
111 |
+
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
|
112 |
+
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
|
113 |
+
"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
|
114 |
+
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
|
115 |
+
"Requirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\n",
|
116 |
+
"Requirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\n",
|
117 |
+
"Requirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\n",
|
118 |
+
"Requirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\n",
|
119 |
+
"Requirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\n",
|
120 |
+
"Requirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\n",
|
121 |
+
"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
|
122 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
|
123 |
+
"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
|
124 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
|
125 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
|
126 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
|
127 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
|
128 |
+
"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
|
129 |
+
"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
|
130 |
+
"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
|
131 |
+
"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
|
132 |
+
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
|
133 |
+
"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
|
134 |
+
"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
|
135 |
+
"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
|
136 |
+
"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
|
137 |
+
"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
|
138 |
+
"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
|
139 |
+
"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
|
140 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
|
141 |
+
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
|
142 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
|
143 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
|
144 |
+
"Requirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\n",
|
145 |
+
"Requirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\n",
|
146 |
+
"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
|
147 |
+
"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
148 |
+
"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
149 |
+
"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
|
150 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
|
151 |
+
"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
|
152 |
+
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
|
153 |
+
"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
|
154 |
+
"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
|
155 |
+
"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
|
156 |
+
"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
|
157 |
+
"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
|
158 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
159 |
+
]
|
160 |
+
}
|
161 |
+
],
|
162 |
+
"source": [
|
163 |
+
"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
|
164 |
+
]
|
165 |
+
},
|
166 |
+
{
|
167 |
+
"cell_type": "code",
|
168 |
+
"execution_count": 4,
|
169 |
+
"metadata": {
|
170 |
+
"datalore": {
|
171 |
+
"hide_input_from_viewers": false,
|
172 |
+
"hide_output_from_viewers": false,
|
173 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
174 |
+
"report_properties": {
|
175 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
176 |
+
},
|
177 |
+
"type": "CODE"
|
178 |
+
},
|
179 |
+
"gather": {
|
180 |
+
"logged": 1706486372154
|
181 |
+
},
|
182 |
+
"tags": []
|
183 |
+
},
|
184 |
+
"outputs": [
|
185 |
+
{
|
186 |
+
"name": "stderr",
|
187 |
+
"output_type": "stream",
|
188 |
+
"text": [
|
189 |
+
"[codecarbon INFO @ 04:20:20] [setup] RAM Tracking...\n",
|
190 |
+
"[codecarbon INFO @ 04:20:20] [setup] GPU Tracking...\n",
|
191 |
+
"[codecarbon INFO @ 04:20:20] Tracking Nvidia GPU via pynvml\n",
|
192 |
+
"[codecarbon INFO @ 04:20:20] [setup] CPU Tracking...\n",
|
193 |
+
"[codecarbon WARNING @ 04:20:20] No CPU tracking mode found. Falling back on CPU constant mode.\n",
|
194 |
+
"[codecarbon WARNING @ 04:20:21] We saw that you have a Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz but we don't know it. Please contact us.\n",
|
195 |
+
"[codecarbon INFO @ 04:20:21] CPU Model on constant consumption mode: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
|
196 |
+
"[codecarbon INFO @ 04:20:21] >>> Tracker's metadata:\n",
|
197 |
+
"[codecarbon INFO @ 04:20:21] Platform system: Linux-5.15.0-1040-azure-x86_64-with-glibc2.10\n",
|
198 |
+
"[codecarbon INFO @ 04:20:21] Python version: 3.8.5\n",
|
199 |
+
"[codecarbon INFO @ 04:20:21] CodeCarbon version: 2.3.3\n",
|
200 |
+
"[codecarbon INFO @ 04:20:21] Available RAM : 440.883 GB\n",
|
201 |
+
"[codecarbon INFO @ 04:20:21] CPU count: 24\n",
|
202 |
+
"[codecarbon INFO @ 04:20:21] CPU model: Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz\n",
|
203 |
+
"[codecarbon INFO @ 04:20:21] GPU count: 4\n",
|
204 |
+
"[codecarbon INFO @ 04:20:21] GPU model: 4 x Tesla V100-PCIE-16GB\n",
|
205 |
+
"[codecarbon WARNING @ 04:20:21] Cloud provider 'azure' do not publish electricity carbon intensity. Using country value instead.\n"
|
206 |
+
]
|
207 |
+
}
|
208 |
+
],
|
209 |
+
"source": [
|
210 |
+
"import pandas as pd\n",
|
211 |
+
"import numpy as np\n",
|
212 |
+
"import torch\n",
|
213 |
+
"import os\n",
|
214 |
+
"from typing import List, Union\n",
|
215 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline\n",
|
216 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
217 |
+
"import shap\n",
|
218 |
+
"import wandb\n",
|
219 |
+
"import evaluate\n",
|
220 |
+
"from codecarbon import EmissionsTracker\n",
|
221 |
+
"import logging\n",
|
222 |
+
"\n",
|
223 |
+
"wandb.finish()\n",
|
224 |
+
"\n",
|
225 |
+
"logging.getLogger('codecarbon').propagate = False\n",
|
226 |
+
"\n",
|
227 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
228 |
+
"tracker = EmissionsTracker()\n",
|
229 |
+
"\n",
|
230 |
+
"%load_ext watermark"
|
231 |
+
]
|
232 |
+
},
|
233 |
+
{
|
234 |
+
"cell_type": "code",
|
235 |
+
"execution_count": 5,
|
236 |
+
"metadata": {
|
237 |
+
"collapsed": false,
|
238 |
+
"gather": {
|
239 |
+
"logged": 1706486372304
|
240 |
+
},
|
241 |
+
"jupyter": {
|
242 |
+
"outputs_hidden": false
|
243 |
+
}
|
244 |
+
},
|
245 |
+
"outputs": [],
|
246 |
+
"source": [
|
247 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
248 |
+
"\n",
|
249 |
+
"SEED: int = 42\n",
|
250 |
+
"\n",
|
251 |
+
"BATCH_SIZE: int = 32\n",
|
252 |
+
"EPOCHS: int = 5\n",
|
253 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
254 |
+
"\n",
|
255 |
+
"# WandB configuration\n",
|
256 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
257 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
258 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
259 |
+
]
|
260 |
+
},
|
261 |
+
{
|
262 |
+
"cell_type": "code",
|
263 |
+
"execution_count": 6,
|
264 |
+
"metadata": {
|
265 |
+
"collapsed": false,
|
266 |
+
"jupyter": {
|
267 |
+
"outputs_hidden": false
|
268 |
+
}
|
269 |
+
},
|
270 |
+
"outputs": [
|
271 |
+
{
|
272 |
+
"name": "stdout",
|
273 |
+
"output_type": "stream",
|
274 |
+
"text": [
|
275 |
+
"re : 2.2.1\n",
|
276 |
+
"pandas : 2.0.2\n",
|
277 |
+
"evaluate: 0.4.1\n",
|
278 |
+
"logging : 0.5.1.2\n",
|
279 |
+
"torch : 1.12.0\n",
|
280 |
+
"shap : 0.44.1\n",
|
281 |
+
"wandb : 0.16.2\n",
|
282 |
+
"numpy : 1.23.5\n",
|
283 |
+
"\n"
|
284 |
+
]
|
285 |
+
}
|
286 |
+
],
|
287 |
+
"source": [
|
288 |
+
"%watermark --iversion"
|
289 |
+
]
|
290 |
+
},
|
291 |
+
{
|
292 |
+
"cell_type": "code",
|
293 |
+
"execution_count": 7,
|
294 |
+
"metadata": {
|
295 |
+
"datalore": {
|
296 |
+
"hide_input_from_viewers": true,
|
297 |
+
"hide_output_from_viewers": true,
|
298 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
299 |
+
"type": "CODE"
|
300 |
+
}
|
301 |
+
},
|
302 |
+
"outputs": [
|
303 |
+
{
|
304 |
+
"name": "stdout",
|
305 |
+
"output_type": "stream",
|
306 |
+
"text": [
|
307 |
+
"/bin/bash: /anaconda/envs/azureml_py38_PT_TF/lib/libtinfo.so.6: no version information available (required by /bin/bash)\n",
|
308 |
+
"Mon Jan 29 04:20:46 2024 \n",
|
309 |
+
"+---------------------------------------------------------------------------------------+\n",
|
310 |
+
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
|
311 |
+
"|-----------------------------------------+----------------------+----------------------+\n",
|
312 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
313 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
314 |
+
"| | | MIG M. |\n",
|
315 |
+
"|=========================================+======================+======================|\n",
|
316 |
+
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
|
317 |
+
"| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
318 |
+
"| | | N/A |\n",
|
319 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
320 |
+
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
|
321 |
+
"| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
322 |
+
"| | | N/A |\n",
|
323 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
324 |
+
"| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\n",
|
325 |
+
"| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
326 |
+
"| | | N/A |\n",
|
327 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
328 |
+
"| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\n",
|
329 |
+
"| N/A 27C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\n",
|
330 |
+
"| | | N/A |\n",
|
331 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
332 |
+
" \n",
|
333 |
+
"+---------------------------------------------------------------------------------------+\n",
|
334 |
+
"| Processes: |\n",
|
335 |
+
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
336 |
+
"| ID ID Usage |\n",
|
337 |
+
"|=======================================================================================|\n",
|
338 |
+
"| No running processes found |\n",
|
339 |
+
"+---------------------------------------------------------------------------------------+\n"
|
340 |
+
]
|
341 |
+
}
|
342 |
+
],
|
343 |
+
"source": [
|
344 |
+
"!nvidia-smi"
|
345 |
+
]
|
346 |
+
},
|
347 |
+
{
|
348 |
+
"cell_type": "markdown",
|
349 |
+
"metadata": {
|
350 |
+
"datalore": {
|
351 |
+
"hide_input_from_viewers": false,
|
352 |
+
"hide_output_from_viewers": false,
|
353 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
354 |
+
"report_properties": {
|
355 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
356 |
+
},
|
357 |
+
"type": "MD"
|
358 |
+
}
|
359 |
+
},
|
360 |
+
"source": [
|
361 |
+
"## Loading the data set"
|
362 |
+
]
|
363 |
+
},
|
364 |
+
{
|
365 |
+
"cell_type": "code",
|
366 |
+
"execution_count": 8,
|
367 |
+
"metadata": {
|
368 |
+
"collapsed": false,
|
369 |
+
"gather": {
|
370 |
+
"logged": 1706486373931
|
371 |
+
},
|
372 |
+
"jupyter": {
|
373 |
+
"outputs_hidden": false
|
374 |
+
}
|
375 |
+
},
|
376 |
+
"outputs": [],
|
377 |
+
"source": [
|
378 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
379 |
+
]
|
380 |
+
},
|
381 |
+
{
|
382 |
+
"cell_type": "code",
|
383 |
+
"execution_count": 9,
|
384 |
+
"metadata": {
|
385 |
+
"collapsed": false,
|
386 |
+
"gather": {
|
387 |
+
"logged": 1706486374218
|
388 |
+
},
|
389 |
+
"jupyter": {
|
390 |
+
"outputs_hidden": false,
|
391 |
+
"source_hidden": false
|
392 |
+
},
|
393 |
+
"nteract": {
|
394 |
+
"transient": {
|
395 |
+
"deleting": false
|
396 |
+
}
|
397 |
+
}
|
398 |
+
},
|
399 |
+
"outputs": [
|
400 |
+
{
|
401 |
+
"data": {
|
402 |
+
"text/plain": [
|
403 |
+
"DatasetDict({\n",
|
404 |
+
" train: Dataset({\n",
|
405 |
+
" features: ['id', 'text', 'label'],\n",
|
406 |
+
" num_rows: 1270444\n",
|
407 |
+
" })\n",
|
408 |
+
" test: Dataset({\n",
|
409 |
+
" features: ['id', 'text', 'label'],\n",
|
410 |
+
" num_rows: 272238\n",
|
411 |
+
" })\n",
|
412 |
+
" val: Dataset({\n",
|
413 |
+
" features: ['id', 'text', 'label'],\n",
|
414 |
+
" num_rows: 272238\n",
|
415 |
+
" })\n",
|
416 |
+
"})"
|
417 |
+
]
|
418 |
+
},
|
419 |
+
"execution_count": 9,
|
420 |
+
"metadata": {},
|
421 |
+
"output_type": "execute_result"
|
422 |
+
}
|
423 |
+
],
|
424 |
+
"source": [
|
425 |
+
"dataset"
|
426 |
+
]
|
427 |
+
},
|
428 |
+
{
|
429 |
+
"cell_type": "code",
|
430 |
+
"execution_count": 10,
|
431 |
+
"metadata": {
|
432 |
+
"gather": {
|
433 |
+
"logged": 1706486374480
|
434 |
+
}
|
435 |
+
},
|
436 |
+
"outputs": [],
|
437 |
+
"source": [
|
438 |
+
"SUBSAMPLING = 1.0\n",
|
439 |
+
"\n",
|
440 |
+
"if SUBSAMPLING < 1:\n",
|
441 |
+
" _ = DatasetDict()\n",
|
442 |
+
" for each in dataset.keys():\n",
|
443 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
444 |
+
"\n",
|
445 |
+
" dataset = _"
|
446 |
+
]
|
447 |
+
},
|
448 |
+
{
|
449 |
+
"cell_type": "markdown",
|
450 |
+
"metadata": {},
|
451 |
+
"source": [
|
452 |
+
"## Tokenisation and encoding"
|
453 |
+
]
|
454 |
+
},
|
455 |
+
{
|
456 |
+
"cell_type": "code",
|
457 |
+
"execution_count": 11,
|
458 |
+
"metadata": {
|
459 |
+
"gather": {
|
460 |
+
"logged": 1706486375030
|
461 |
+
}
|
462 |
+
},
|
463 |
+
"outputs": [],
|
464 |
+
"source": [
|
465 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
466 |
+
" return ds_enc"
|
467 |
+
]
|
468 |
+
},
|
469 |
+
{
|
470 |
+
"cell_type": "markdown",
|
471 |
+
"metadata": {},
|
472 |
+
"source": [
|
473 |
+
"## Evaluation metrics"
|
474 |
+
]
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"cell_type": "code",
|
478 |
+
"execution_count": 12,
|
479 |
+
"metadata": {
|
480 |
+
"gather": {
|
481 |
+
"logged": 1706486375197
|
482 |
+
}
|
483 |
+
},
|
484 |
+
"outputs": [],
|
485 |
+
"source": [
|
486 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
487 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
488 |
+
"f1 = evaluate.load(\"f1\")"
|
489 |
+
]
|
490 |
+
},
|
491 |
+
{
|
492 |
+
"cell_type": "code",
|
493 |
+
"execution_count": 13,
|
494 |
+
"metadata": {
|
495 |
+
"gather": {
|
496 |
+
"logged": 1706486375361
|
497 |
+
}
|
498 |
+
},
|
499 |
+
"outputs": [],
|
500 |
+
"source": [
|
501 |
+
"def compute_metrics(eval_pred):\n",
|
502 |
+
" predictions, labels = eval_pred\n",
|
503 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
504 |
+
" return {\n",
|
505 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
506 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
507 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
508 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
509 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
510 |
+
" 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
511 |
+
" }"
|
512 |
+
]
|
513 |
+
},
|
514 |
+
{
|
515 |
+
"cell_type": "markdown",
|
516 |
+
"metadata": {},
|
517 |
+
"source": [
|
518 |
+
"## Training"
|
519 |
+
]
|
520 |
+
},
|
521 |
+
{
|
522 |
+
"cell_type": "markdown",
|
523 |
+
"metadata": {},
|
524 |
+
"source": [
|
525 |
+
"We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
|
526 |
+
]
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"cell_type": "code",
|
530 |
+
"execution_count": 14,
|
531 |
+
"metadata": {
|
532 |
+
"gather": {
|
533 |
+
"logged": 1706486375569
|
534 |
+
}
|
535 |
+
},
|
536 |
+
"outputs": [],
|
537 |
+
"source": [
|
538 |
+
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
539 |
+
]
|
540 |
+
},
|
541 |
+
{
|
542 |
+
"cell_type": "code",
|
543 |
+
"execution_count": 15,
|
544 |
+
"metadata": {
|
545 |
+
"gather": {
|
546 |
+
"logged": 1706486433708
|
547 |
+
}
|
548 |
+
},
|
549 |
+
"outputs": [
|
550 |
+
{
|
551 |
+
"name": "stderr",
|
552 |
+
"output_type": "stream",
|
553 |
+
"text": [
|
554 |
+
"Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 1270444/1270444 [08:09<00:00, 2595.90 examples/s]\n",
|
555 |
+
"Map: 100%|████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272238/272238 [01:45<00:00, 2585.25 examples/s]\n",
|
556 |
+
"Map: 100%|██████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████���█████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████████| 272238/272238 [01:44<00:00, 2605.66 examples/s]\n"
|
557 |
+
]
|
558 |
+
}
|
559 |
+
],
|
560 |
+
"source": [
|
561 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
562 |
+
"\n",
|
563 |
+
"cols = dataset[\"train\"].column_names\n",
|
564 |
+
"cols.remove(\"label\")\n",
|
565 |
+
"ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True), batched=True, remove_columns=cols)\n"
|
566 |
+
]
|
567 |
+
},
|
568 |
+
{
|
569 |
+
"cell_type": "code",
|
570 |
+
"execution_count": 16,
|
571 |
+
"metadata": {},
|
572 |
+
"outputs": [
|
573 |
+
{
|
574 |
+
"name": "stderr",
|
575 |
+
"output_type": "stream",
|
576 |
+
"text": [
|
577 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
|
578 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
579 |
+
]
|
580 |
+
}
|
581 |
+
],
|
582 |
+
"source": [
|
583 |
+
"\n",
|
584 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
585 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
586 |
+
" id2label=label_map, \n",
|
587 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
588 |
+
"\n",
|
589 |
+
"args = TrainingArguments(\n",
|
590 |
+
" output_dir=\"vaers\",\n",
|
591 |
+
" evaluation_strategy=\"epoch\",\n",
|
592 |
+
" save_strategy=\"epoch\",\n",
|
593 |
+
" learning_rate=2e-5,\n",
|
594 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
595 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
596 |
+
" num_train_epochs=EPOCHS,\n",
|
597 |
+
" weight_decay=.01,\n",
|
598 |
+
" logging_steps=1,\n",
|
599 |
+
" load_best_model_at_end=True,\n",
|
600 |
+
" run_name=f\"daedra-training\",\n",
|
601 |
+
" report_to=[\"wandb\"])\n",
|
602 |
+
"\n",
|
603 |
+
"trainer = Trainer(\n",
|
604 |
+
" model=model,\n",
|
605 |
+
" args=args,\n",
|
606 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
607 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
608 |
+
" tokenizer=tokenizer,\n",
|
609 |
+
" compute_metrics=compute_metrics)"
|
610 |
+
]
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"cell_type": "code",
|
614 |
+
"execution_count": 17,
|
615 |
+
"metadata": {
|
616 |
+
"gather": {
|
617 |
+
"logged": 1706486444806
|
618 |
+
}
|
619 |
+
},
|
620 |
+
"outputs": [
|
621 |
+
{
|
622 |
+
"name": "stderr",
|
623 |
+
"output_type": "stream",
|
624 |
+
"text": [
|
625 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33mchrisvoncsefalvay\u001b[0m. Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n",
|
626 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
627 |
+
]
|
628 |
+
},
|
629 |
+
{
|
630 |
+
"data": {
|
631 |
+
"text/html": [
|
632 |
+
"Tracking run with wandb version 0.16.2"
|
633 |
+
],
|
634 |
+
"text/plain": [
|
635 |
+
"<IPython.core.display.HTML object>"
|
636 |
+
]
|
637 |
+
},
|
638 |
+
"metadata": {},
|
639 |
+
"output_type": "display_data"
|
640 |
+
},
|
641 |
+
{
|
642 |
+
"data": {
|
643 |
+
"text/html": [
|
644 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_043232-tl59png2</code>"
|
645 |
+
],
|
646 |
+
"text/plain": [
|
647 |
+
"<IPython.core.display.HTML object>"
|
648 |
+
]
|
649 |
+
},
|
650 |
+
"metadata": {},
|
651 |
+
"output_type": "display_data"
|
652 |
+
},
|
653 |
+
{
|
654 |
+
"data": {
|
655 |
+
"text/html": [
|
656 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
657 |
+
],
|
658 |
+
"text/plain": [
|
659 |
+
"<IPython.core.display.HTML object>"
|
660 |
+
]
|
661 |
+
},
|
662 |
+
"metadata": {},
|
663 |
+
"output_type": "display_data"
|
664 |
+
},
|
665 |
+
{
|
666 |
+
"data": {
|
667 |
+
"text/html": [
|
668 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
669 |
+
],
|
670 |
+
"text/plain": [
|
671 |
+
"<IPython.core.display.HTML object>"
|
672 |
+
]
|
673 |
+
},
|
674 |
+
"metadata": {},
|
675 |
+
"output_type": "display_data"
|
676 |
+
},
|
677 |
+
{
|
678 |
+
"data": {
|
679 |
+
"text/html": [
|
680 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2</a>"
|
681 |
+
],
|
682 |
+
"text/plain": [
|
683 |
+
"<IPython.core.display.HTML object>"
|
684 |
+
]
|
685 |
+
},
|
686 |
+
"metadata": {},
|
687 |
+
"output_type": "display_data"
|
688 |
+
},
|
689 |
+
{
|
690 |
+
"data": {
|
691 |
+
"text/html": [
|
692 |
+
"Finishing last run (ID:tl59png2) before initializing another..."
|
693 |
+
],
|
694 |
+
"text/plain": [
|
695 |
+
"<IPython.core.display.HTML object>"
|
696 |
+
]
|
697 |
+
},
|
698 |
+
"metadata": {},
|
699 |
+
"output_type": "display_data"
|
700 |
+
},
|
701 |
+
{
|
702 |
+
"data": {
|
703 |
+
"text/html": [
|
704 |
+
" View run <strong style=\"color:#cdcd00\">daedra_training_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/tl59png2</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v0' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v0</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
705 |
+
],
|
706 |
+
"text/plain": [
|
707 |
+
"<IPython.core.display.HTML object>"
|
708 |
+
]
|
709 |
+
},
|
710 |
+
"metadata": {},
|
711 |
+
"output_type": "display_data"
|
712 |
+
},
|
713 |
+
{
|
714 |
+
"data": {
|
715 |
+
"text/html": [
|
716 |
+
"Find logs at: <code>./wandb/run-20240129_043232-tl59png2/logs</code>"
|
717 |
+
],
|
718 |
+
"text/plain": [
|
719 |
+
"<IPython.core.display.HTML object>"
|
720 |
+
]
|
721 |
+
},
|
722 |
+
"metadata": {},
|
723 |
+
"output_type": "display_data"
|
724 |
+
},
|
725 |
+
{
|
726 |
+
"data": {
|
727 |
+
"text/html": [
|
728 |
+
"Successfully finished last run (ID:tl59png2). Initializing new run:<br/>"
|
729 |
+
],
|
730 |
+
"text/plain": [
|
731 |
+
"<IPython.core.display.HTML object>"
|
732 |
+
]
|
733 |
+
},
|
734 |
+
"metadata": {},
|
735 |
+
"output_type": "display_data"
|
736 |
+
},
|
737 |
+
{
|
738 |
+
"data": {
|
739 |
+
"text/html": [
|
740 |
+
"Tracking run with wandb version 0.16.2"
|
741 |
+
],
|
742 |
+
"text/plain": [
|
743 |
+
"<IPython.core.display.HTML object>"
|
744 |
+
]
|
745 |
+
},
|
746 |
+
"metadata": {},
|
747 |
+
"output_type": "display_data"
|
748 |
+
},
|
749 |
+
{
|
750 |
+
"data": {
|
751 |
+
"text/html": [
|
752 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_043243-x8j2xw0x</code>"
|
753 |
+
],
|
754 |
+
"text/plain": [
|
755 |
+
"<IPython.core.display.HTML object>"
|
756 |
+
]
|
757 |
+
},
|
758 |
+
"metadata": {},
|
759 |
+
"output_type": "display_data"
|
760 |
+
},
|
761 |
+
{
|
762 |
+
"data": {
|
763 |
+
"text/html": [
|
764 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x' target=\"_blank\">daedra_training_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
765 |
+
],
|
766 |
+
"text/plain": [
|
767 |
+
"<IPython.core.display.HTML object>"
|
768 |
+
]
|
769 |
+
},
|
770 |
+
"metadata": {},
|
771 |
+
"output_type": "display_data"
|
772 |
+
},
|
773 |
+
{
|
774 |
+
"data": {
|
775 |
+
"text/html": [
|
776 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
777 |
+
],
|
778 |
+
"text/plain": [
|
779 |
+
"<IPython.core.display.HTML object>"
|
780 |
+
]
|
781 |
+
},
|
782 |
+
"metadata": {},
|
783 |
+
"output_type": "display_data"
|
784 |
+
},
|
785 |
+
{
|
786 |
+
"data": {
|
787 |
+
"text/html": [
|
788 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x</a>"
|
789 |
+
],
|
790 |
+
"text/plain": [
|
791 |
+
"<IPython.core.display.HTML object>"
|
792 |
+
]
|
793 |
+
},
|
794 |
+
"metadata": {},
|
795 |
+
"output_type": "display_data"
|
796 |
+
},
|
797 |
+
{
|
798 |
+
"data": {
|
799 |
+
"text/html": [
|
800 |
+
"<button onClick=\"this.nextSibling.style.display='block';this.style.display='none';\">Display W&B run</button><iframe src='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/x8j2xw0x?jupyter=true' style='border:none;width:100%;height:420px;display:none;'></iframe>"
|
801 |
+
],
|
802 |
+
"text/plain": [
|
803 |
+
"<wandb.sdk.wandb_run.Run at 0x7ffa3d0e9bb0>"
|
804 |
+
]
|
805 |
+
},
|
806 |
+
"execution_count": 17,
|
807 |
+
"metadata": {},
|
808 |
+
"output_type": "execute_result"
|
809 |
+
}
|
810 |
+
],
|
811 |
+
"source": [
|
812 |
+
"if SUBSAMPLING != 1.0:\n",
|
813 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
814 |
+
"else:\n",
|
815 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
816 |
+
"\n",
|
817 |
+
"wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
818 |
+
"wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
819 |
+
" \n",
|
820 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)"
|
821 |
+
]
|
822 |
+
},
|
823 |
+
{
|
824 |
+
"cell_type": "code",
|
825 |
+
"execution_count": 18,
|
826 |
+
"metadata": {
|
827 |
+
"gather": {
|
828 |
+
"logged": 1706486541798
|
829 |
+
}
|
830 |
+
},
|
831 |
+
"outputs": [
|
832 |
+
{
|
833 |
+
"name": "stderr",
|
834 |
+
"output_type": "stream",
|
835 |
+
"text": [
|
836 |
+
"Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
|
837 |
+
]
|
838 |
+
},
|
839 |
+
{
|
840 |
+
"data": {
|
841 |
+
"text/html": [
|
842 |
+
"\n",
|
843 |
+
" <div>\n",
|
844 |
+
" \n",
|
845 |
+
" <progress value='394' max='49630' style='width:300px; height:20px; vertical-align: middle;'></progress>\n",
|
846 |
+
" [ 394/49630 04:13 < 8:50:52, 1.55 it/s, Epoch 0.04/5]\n",
|
847 |
+
" </div>\n",
|
848 |
+
" <table border=\"1\" class=\"dataframe\">\n",
|
849 |
+
" <thead>\n",
|
850 |
+
" <tr style=\"text-align: left;\">\n",
|
851 |
+
" <th>Epoch</th>\n",
|
852 |
+
" <th>Training Loss</th>\n",
|
853 |
+
" <th>Validation Loss</th>\n",
|
854 |
+
" </tr>\n",
|
855 |
+
" </thead>\n",
|
856 |
+
" <tbody>\n",
|
857 |
+
" </tbody>\n",
|
858 |
+
"</table><p>"
|
859 |
+
],
|
860 |
+
"text/plain": [
|
861 |
+
"<IPython.core.display.HTML object>"
|
862 |
+
]
|
863 |
+
},
|
864 |
+
"metadata": {},
|
865 |
+
"output_type": "display_data"
|
866 |
+
},
|
867 |
+
{
|
868 |
+
"name": "stderr",
|
869 |
+
"output_type": "stream",
|
870 |
+
"text": [
|
871 |
+
"[codecarbon INFO @ 04:33:12] Energy consumed for RAM : 0.000689 kWh. RAM Power : 165.33123922348022 W\n",
|
872 |
+
"[codecarbon INFO @ 04:33:12] Energy consumed for all GPUs : 0.001450 kWh. Total GPU Power : 347.66451200921796 W\n",
|
873 |
+
"[codecarbon INFO @ 04:33:12] Energy consumed for all CPUs : 0.000177 kWh. Total CPU Power : 42.5 W\n",
|
874 |
+
"[codecarbon INFO @ 04:33:12] 0.002317 kWh of electricity used since the beginning.\n",
|
875 |
+
"[codecarbon INFO @ 04:33:27] Energy consumed for RAM : 0.001378 kWh. RAM Power : 165.33123922348022 W\n",
|
876 |
+
"[codecarbon INFO @ 04:33:27] Energy consumed for all GPUs : 0.004012 kWh. Total GPU Power : 615.4556826768763 W\n",
|
877 |
+
"[codecarbon INFO @ 04:33:27] Energy consumed for all CPUs : 0.000355 kWh. Total CPU Power : 42.5 W\n",
|
878 |
+
"[codecarbon INFO @ 04:33:27] 0.005745 kWh of electricity used since the beginning.\n",
|
879 |
+
"[codecarbon INFO @ 04:33:42] Energy consumed for RAM : 0.002066 kWh. RAM Power : 165.33123922348022 W\n",
|
880 |
+
"[codecarbon INFO @ 04:33:42] Energy consumed for all GPUs : 0.006596 kWh. Total GPU Power : 620.9110211178034 W\n",
|
881 |
+
"[codecarbon INFO @ 04:33:42] Energy consumed for all CPUs : 0.000532 kWh. Total CPU Power : 42.5 W\n",
|
882 |
+
"[codecarbon INFO @ 04:33:42] 0.009194 kWh of electricity used since the beginning.\n",
|
883 |
+
"[codecarbon INFO @ 04:33:57] Energy consumed for RAM : 0.002754 kWh. RAM Power : 165.33123922348022 W\n",
|
884 |
+
"[codecarbon INFO @ 04:33:57] Energy consumed for all GPUs : 0.009183 kWh. Total GPU Power : 621.1270289526989 W\n",
|
885 |
+
"[codecarbon INFO @ 04:33:57] Energy consumed for all CPUs : 0.000709 kWh. Total CPU Power : 42.5 W\n",
|
886 |
+
"[codecarbon INFO @ 04:33:57] 0.012645 kWh of electricity used since the beginning.\n",
|
887 |
+
"[codecarbon INFO @ 04:34:12] Energy consumed for RAM : 0.003442 kWh. RAM Power : 165.33123922348022 W\n",
|
888 |
+
"[codecarbon INFO @ 04:34:12] Energy consumed for all GPUs : 0.011798 kWh. Total GPU Power : 628.3875606622404 W\n",
|
889 |
+
"[codecarbon INFO @ 04:34:12] Energy consumed for all CPUs : 0.000886 kWh. Total CPU Power : 42.5 W\n",
|
890 |
+
"[codecarbon INFO @ 04:34:12] 0.016125 kWh of electricity used since the beginning.\n",
|
891 |
+
"[codecarbon INFO @ 04:34:27] Energy consumed for RAM : 0.004130 kWh. RAM Power : 165.33123922348022 W\n",
|
892 |
+
"[codecarbon INFO @ 04:34:27] Energy consumed for all GPUs : 0.014431 kWh. Total GPU Power : 632.4054645127197 W\n",
|
893 |
+
"[codecarbon INFO @ 04:34:27] Energy consumed for all CPUs : 0.001063 kWh. Total CPU Power : 42.5 W\n",
|
894 |
+
"[codecarbon INFO @ 04:34:27] 0.019623 kWh of electricity used since the beginning.\n",
|
895 |
+
"[codecarbon INFO @ 04:34:42] Energy consumed for RAM : 0.004818 kWh. RAM Power : 165.33123922348022 W\n",
|
896 |
+
"[codecarbon INFO @ 04:34:42] Energy consumed for all GPUs : 0.017064 kWh. Total GPU Power : 632.6571124342939 W\n",
|
897 |
+
"[codecarbon INFO @ 04:34:42] Energy consumed for all CPUs : 0.001240 kWh. Total CPU Power : 42.5 W\n",
|
898 |
+
"[codecarbon INFO @ 04:34:42] 0.023122 kWh of electricity used since the beginning.\n",
|
899 |
+
"[codecarbon INFO @ 04:34:57] Energy consumed for RAM : 0.005506 kWh. RAM Power : 165.33123922348022 W\n",
|
900 |
+
"[codecarbon INFO @ 04:34:57] Energy consumed for all GPUs : 0.019707 kWh. Total GPU Power : 634.7921879339333 W\n",
|
901 |
+
"[codecarbon INFO @ 04:34:57] Energy consumed for all CPUs : 0.001417 kWh. Total CPU Power : 42.5 W\n",
|
902 |
+
"[codecarbon INFO @ 04:34:57] 0.026631 kWh of electricity used since the beginning.\n",
|
903 |
+
"[codecarbon INFO @ 04:35:12] Energy consumed for RAM : 0.006194 kWh. RAM Power : 165.33123922348022 W\n",
|
904 |
+
"[codecarbon INFO @ 04:35:12] Energy consumed for all GPUs : 0.022334 kWh. Total GPU Power : 630.3609394863598 W\n",
|
905 |
+
"[codecarbon INFO @ 04:35:12] Energy consumed for all CPUs : 0.001594 kWh. Total CPU Power : 42.5 W\n",
|
906 |
+
"[codecarbon INFO @ 04:35:12] 0.030123 kWh of electricity used since the beginning.\n",
|
907 |
+
"[codecarbon INFO @ 04:35:27] Energy consumed for RAM : 0.006882 kWh. RAM Power : 165.33123922348022 W\n",
|
908 |
+
"[codecarbon INFO @ 04:35:27] Energy consumed for all GPUs : 0.024956 kWh. Total GPU Power : 630.704729336156 W\n",
|
909 |
+
"[codecarbon INFO @ 04:35:27] Energy consumed for all CPUs : 0.001771 kWh. Total CPU Power : 42.5 W\n",
|
910 |
+
"[codecarbon INFO @ 04:35:27] 0.033609 kWh of electricity used since the beginning.\n",
|
911 |
+
"[codecarbon INFO @ 04:35:42] Energy consumed for RAM : 0.007570 kWh. RAM Power : 165.33123922348022 W\n",
|
912 |
+
"[codecarbon INFO @ 04:35:42] Energy consumed for all GPUs : 0.027604 kWh. Total GPU Power : 636.1545465788125 W\n",
|
913 |
+
"[codecarbon INFO @ 04:35:42] Energy consumed for all CPUs : 0.001948 kWh. Total CPU Power : 42.5 W\n",
|
914 |
+
"[codecarbon INFO @ 04:35:42] 0.037121 kWh of electricity used since the beginning.\n",
|
915 |
+
"[codecarbon INFO @ 04:35:57] Energy consumed for RAM : 0.008258 kWh. RAM Power : 165.33123922348022 W\n",
|
916 |
+
"[codecarbon INFO @ 04:35:57] Energy consumed for all GPUs : 0.030255 kWh. Total GPU Power : 636.9769106141198 W\n",
|
917 |
+
"[codecarbon INFO @ 04:35:57] Energy consumed for all CPUs : 0.002125 kWh. Total CPU Power : 42.5 W\n",
|
918 |
+
"[codecarbon INFO @ 04:35:57] 0.040638 kWh of electricity used since the beginning.\n",
|
919 |
+
"[codecarbon INFO @ 04:36:12] Energy consumed for RAM : 0.008946 kWh. RAM Power : 165.33123922348022 W\n",
|
920 |
+
"[codecarbon INFO @ 04:36:12] Energy consumed for all GPUs : 0.032913 kWh. Total GPU Power : 638.3412890613937 W\n",
|
921 |
+
"[codecarbon INFO @ 04:36:12] Energy consumed for all CPUs : 0.002302 kWh. Total CPU Power : 42.5 W\n",
|
922 |
+
"[codecarbon INFO @ 04:36:12] 0.044161 kWh of electricity used since the beginning.\n",
|
923 |
+
"[codecarbon INFO @ 04:36:27] Energy consumed for RAM : 0.009634 kWh. RAM Power : 165.33123922348022 W\n",
|
924 |
+
"[codecarbon INFO @ 04:36:27] Energy consumed for all GPUs : 0.035515 kWh. Total GPU Power : 625.0502398771333 W\n",
|
925 |
+
"[codecarbon INFO @ 04:36:27] Energy consumed for all CPUs : 0.002479 kWh. Total CPU Power : 42.5 W\n",
|
926 |
+
"[codecarbon INFO @ 04:36:27] 0.047628 kWh of electricity used since the beginning.\n",
|
927 |
+
"[codecarbon INFO @ 04:36:42] Energy consumed for RAM : 0.010322 kWh. RAM Power : 165.33123922348022 W\n",
|
928 |
+
"[codecarbon INFO @ 04:36:42] Energy consumed for all GPUs : 0.038183 kWh. Total GPU Power : 641.00719087638 W\n",
|
929 |
+
"[codecarbon INFO @ 04:36:42] Energy consumed for all CPUs : 0.002656 kWh. Total CPU Power : 42.5 W\n",
|
930 |
+
"[codecarbon INFO @ 04:36:42] 0.051162 kWh of electricity used since the beginning.\n",
|
931 |
+
"[codecarbon INFO @ 04:36:57] Energy consumed for RAM : 0.011010 kWh. RAM Power : 165.33123922348022 W\n",
|
932 |
+
"[codecarbon INFO @ 04:36:57] Energy consumed for all GPUs : 0.040821 kWh. Total GPU Power : 633.4817689949092 W\n",
|
933 |
+
"[codecarbon INFO @ 04:36:57] Energy consumed for all CPUs : 0.002834 kWh. Total CPU Power : 42.5 W\n",
|
934 |
+
"[codecarbon INFO @ 04:36:57] 0.054665 kWh of electricity used since the beginning.\n",
|
935 |
+
"[codecarbon INFO @ 04:37:12] Energy consumed for RAM : 0.011698 kWh. RAM Power : 165.33123922348022 W\n",
|
936 |
+
"[codecarbon INFO @ 04:37:12] Energy consumed for all GPUs : 0.043484 kWh. Total GPU Power : 639.8452880027475 W\n",
|
937 |
+
"[codecarbon INFO @ 04:37:12] Energy consumed for all CPUs : 0.003011 kWh. Total CPU Power : 42.5 W\n",
|
938 |
+
"[codecarbon INFO @ 04:37:12] 0.058193 kWh of electricity used since the beginning.\n"
|
939 |
+
]
|
940 |
+
}
|
941 |
+
],
|
942 |
+
"source": [
|
943 |
+
"tracker.start()\n",
|
944 |
+
"trainer.train()\n",
|
945 |
+
"tracker.stop()\n"
|
946 |
+
]
|
947 |
+
},
|
948 |
+
{
|
949 |
+
"cell_type": "code",
|
950 |
+
"execution_count": null,
|
951 |
+
"metadata": {
|
952 |
+
"gather": {
|
953 |
+
"logged": 1706486541918
|
954 |
+
}
|
955 |
+
},
|
956 |
+
"outputs": [],
|
957 |
+
"source": [
|
958 |
+
"wandb.finish()"
|
959 |
+
]
|
960 |
+
},
|
961 |
+
{
|
962 |
+
"cell_type": "code",
|
963 |
+
"execution_count": null,
|
964 |
+
"metadata": {
|
965 |
+
"gather": {
|
966 |
+
"logged": 1706486541928
|
967 |
+
}
|
968 |
+
},
|
969 |
+
"outputs": [],
|
970 |
+
"source": [
|
971 |
+
"variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
|
972 |
+
"tokenizer._tokenizer.save(\"tokenizer.json\")\n",
|
973 |
+
"tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
|
974 |
+
"sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
|
975 |
+
"\n",
|
976 |
+
"model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
|
977 |
+
" variant=variant,\n",
|
978 |
+
" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,})\")"
|
979 |
+
]
|
980 |
+
}
|
997 |
+
],
|
998 |
+
"metadata": {
|
999 |
+
"datalore": {
|
1000 |
+
"base_environment": "default",
|
1001 |
+
"computation_mode": "JUPYTER",
|
1002 |
+
"package_manager": "pip",
|
1003 |
+
"packages": [
|
1004 |
+
{
|
1005 |
+
"name": "datasets",
|
1006 |
+
"source": "PIP",
|
1007 |
+
"version": "2.16.1"
|
1008 |
+
},
|
1009 |
+
{
|
1010 |
+
"name": "torch",
|
1011 |
+
"source": "PIP",
|
1012 |
+
"version": "2.1.2"
|
1013 |
+
},
|
1014 |
+
{
|
1015 |
+
"name": "accelerate",
|
1016 |
+
"source": "PIP",
|
1017 |
+
"version": "0.26.1"
|
1018 |
+
}
|
1019 |
+
],
|
1020 |
+
"report_row_ids": [
|
1021 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
1022 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
1023 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
1024 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
1025 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
1026 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
1027 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
1028 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
1029 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
1030 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
1031 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
1032 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
1033 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
1034 |
+
],
|
1035 |
+
"version": 3
|
1036 |
+
},
|
1037 |
+
"kernel_info": {
|
1038 |
+
"name": "python38-azureml-pt-tf"
|
1039 |
+
},
|
1040 |
+
"kernelspec": {
|
1041 |
+
"display_name": "azureml_py38_PT_TF",
|
1042 |
+
"language": "python",
|
1043 |
+
"name": "python3"
|
1044 |
+
},
|
1045 |
+
"language_info": {
|
1046 |
+
"codemirror_mode": {
|
1047 |
+
"name": "ipython",
|
1048 |
+
"version": 3
|
1049 |
+
},
|
1050 |
+
"file_extension": ".py",
|
1051 |
+
"mimetype": "text/x-python",
|
1052 |
+
"name": "python",
|
1053 |
+
"nbconvert_exporter": "python",
|
1054 |
+
"pygments_lexer": "ipython3",
|
1055 |
+
"version": "3.8.5"
|
1056 |
+
},
|
1057 |
+
"microsoft": {
|
1058 |
+
"host": {
|
1059 |
+
"AzureML": {
|
1060 |
+
"notebookHasBeenCompleted": true
|
1061 |
+
}
|
1062 |
+
},
|
1063 |
+
"ms_spell_check": {
|
1064 |
+
"ms_spell_check_language": "en"
|
1065 |
+
}
|
1066 |
+
},
|
1067 |
+
"nteract": {
|
1068 |
+
"version": "nteract-front-end@1.0.0"
|
1069 |
+
}
|
1070 |
+
},
|
1071 |
+
"nbformat": 4,
|
1072 |
+
"nbformat_minor": 4
|
1073 |
+
}
|
notebooks/.ipynb_aml_checkpoints/DAEDRA-checkpoint2024-0-30-21-44-8Z.ipynb
ADDED
@@ -0,0 +1,671 @@
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {}
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"source": [
|
15 |
+
"%pip install accelerate -U"
|
16 |
+
],
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"output_type": "stream",
|
20 |
+
"name": "stdout",
|
21 |
+
"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"execution_count": 1,
|
25 |
+
"metadata": {
|
26 |
+
"gather": {
|
27 |
+
"logged": 1706475754655
|
28 |
+
},
|
29 |
+
"nteract": {
|
30 |
+
"transient": {
|
31 |
+
"deleting": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"tags": []
|
35 |
+
}
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"source": [
|
40 |
+
"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
|
41 |
+
],
|
42 |
+
"outputs": [
|
43 |
+
{
|
44 |
+
"output_type": "stream",
|
45 |
+
"name": "stdout",
|
46 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"execution_count": 2,
|
50 |
+
"metadata": {
|
51 |
+
"nteract": {
|
52 |
+
"transient": {
|
53 |
+
"deleting": false
|
54 |
+
}
|
55 |
+
}
|
56 |
+
}
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"import pandas as pd\n",
|
62 |
+
"import numpy as np\n",
|
63 |
+
"import torch\n",
|
64 |
+
"import os\n",
|
65 |
+
"from typing import List, Union\n",
|
66 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
|
67 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
68 |
+
"import shap\n",
|
69 |
+
"import wandb\n",
|
70 |
+
"import evaluate\n",
|
71 |
+
"import logging\n",
|
72 |
+
"\n",
|
73 |
+
"wandb.finish()\n",
|
74 |
+
"\n",
|
75 |
+
"\n",
|
76 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
77 |
+
"\n",
|
78 |
+
"%load_ext watermark"
|
79 |
+
],
|
80 |
+
"outputs": [
|
81 |
+
{
|
82 |
+
"output_type": "stream",
|
83 |
+
"name": "stderr",
|
84 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 17:46:15.020290: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 17:46:16.031641: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031779: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
85 |
+
}
|
86 |
+
],
|
87 |
+
"execution_count": 3,
|
88 |
+
"metadata": {
|
89 |
+
"datalore": {
|
90 |
+
"hide_input_from_viewers": false,
|
91 |
+
"hide_output_from_viewers": false,
|
92 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
93 |
+
"report_properties": {
|
94 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
95 |
+
},
|
96 |
+
"type": "CODE"
|
97 |
+
},
|
98 |
+
"gather": {
|
99 |
+
"logged": 1706550378660
|
100 |
+
},
|
101 |
+
"tags": []
|
102 |
+
}
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"source": [
|
107 |
+
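"# Prefer the GPU when one is available\n",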
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
108 |
+
"\n",
|
109 |
+
"SEED: int = 42\n",
|
110 |
+
"\n",
|
111 |
+
"BATCH_SIZE: int = 32\n",
|
112 |
+
"EPOCHS: int = 5\n",
|
113 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
114 |
+
"\n",
|
115 |
+
"# WandB configuration\n",
|
116 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
117 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
118 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
119 |
+
],
|
120 |
+
"outputs": [],
|
121 |
+
"execution_count": 4,
|
122 |
+
"metadata": {
|
123 |
+
"collapsed": false,
|
124 |
+
"gather": {
|
125 |
+
"logged": 1706550378812
|
126 |
+
},
|
127 |
+
"jupyter": {
|
128 |
+
"outputs_hidden": false
|
129 |
+
}
|
130 |
+
}
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "code",
|
134 |
+
"source": [
|
135 |
+
"%watermark --iversion"
|
136 |
+
],
|
137 |
+
"outputs": [
|
138 |
+
{
|
139 |
+
"output_type": "stream",
|
140 |
+
"name": "stdout",
|
141 |
+
"text": "shap : 0.44.1\npandas : 2.0.2\nwandb : 0.16.2\nre : 2.2.1\nevaluate: 0.4.1\ntorch : 1.12.0\nnumpy : 1.23.5\nlogging : 0.5.1.2\n\n"
|
142 |
+
}
|
143 |
+
],
|
144 |
+
"execution_count": 5,
|
145 |
+
"metadata": {
|
146 |
+
"collapsed": false,
|
147 |
+
"jupyter": {
|
148 |
+
"outputs_hidden": false
|
149 |
+
}
|
150 |
+
}
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "code",
|
154 |
+
"source": [
|
155 |
+
"!nvidia-smi"
|
156 |
+
],
|
157 |
+
"outputs": [
|
158 |
+
{
|
159 |
+
"output_type": "stream",
|
160 |
+
"name": "stdout",
|
161 |
+
"text": "Mon Jan 29 17:46:18 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
162 |
+
}
|
163 |
+
],
|
164 |
+
"execution_count": 6,
|
165 |
+
"metadata": {
|
166 |
+
"datalore": {
|
167 |
+
"hide_input_from_viewers": true,
|
168 |
+
"hide_output_from_viewers": true,
|
169 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
170 |
+
"type": "CODE"
|
171 |
+
}
|
172 |
+
}
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "markdown",
|
176 |
+
"source": [
|
177 |
+
"## Loading the data set"
|
178 |
+
],
|
179 |
+
"metadata": {
|
180 |
+
"datalore": {
|
181 |
+
"hide_input_from_viewers": false,
|
182 |
+
"hide_output_from_viewers": false,
|
183 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
184 |
+
"report_properties": {
|
185 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
186 |
+
},
|
187 |
+
"type": "MD"
|
188 |
+
}
|
189 |
+
}
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"source": [
|
194 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
195 |
+
],
|
196 |
+
"outputs": [],
|
197 |
+
"execution_count": 7,
|
198 |
+
"metadata": {
|
199 |
+
"collapsed": false,
|
200 |
+
"gather": {
|
201 |
+
"logged": 1706550381141
|
202 |
+
},
|
203 |
+
"jupyter": {
|
204 |
+
"outputs_hidden": false
|
205 |
+
}
|
206 |
+
}
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"source": [
|
211 |
+
"dataset"
|
212 |
+
],
|
213 |
+
"outputs": [
|
214 |
+
{
|
215 |
+
"output_type": "execute_result",
|
216 |
+
"execution_count": 8,
|
217 |
+
"data": {
|
218 |
+
"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
|
219 |
+
},
|
220 |
+
"metadata": {}
|
221 |
+
}
|
222 |
+
],
|
223 |
+
"execution_count": 8,
|
224 |
+
"metadata": {
|
225 |
+
"collapsed": false,
|
226 |
+
"gather": {
|
227 |
+
"logged": 1706550381303
|
228 |
+
},
|
229 |
+
"jupyter": {
|
230 |
+
"outputs_hidden": false,
|
231 |
+
"source_hidden": false
|
232 |
+
},
|
233 |
+
"nteract": {
|
234 |
+
"transient": {
|
235 |
+
"deleting": false
|
236 |
+
}
|
237 |
+
}
|
238 |
+
}
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"cell_type": "code",
|
242 |
+
"source": [
|
243 |
+
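"# Fraction of each split to use for this run; set to 1.0 to train on the full dataset\n",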
"SUBSAMPLING = 0.01\n",
|
244 |
+
"\n",
|
245 |
+
"if SUBSAMPLING < 1:\n",
|
246 |
+
" _ = DatasetDict()\n",
|
247 |
+
" for each in dataset.keys():\n",
|
248 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
249 |
+
"\n",
|
250 |
+
" dataset = _"
|
251 |
+
],
|
252 |
+
"outputs": [],
|
253 |
+
"execution_count": 9,
|
254 |
+
"metadata": {
|
255 |
+
"gather": {
|
256 |
+
"logged": 1706550381472
|
257 |
+
}
|
258 |
+
}
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "markdown",
|
262 |
+
"source": [
|
263 |
+
"## Tokenisation and encoding"
|
264 |
+
],
|
265 |
+
"metadata": {}
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"source": [
|
270 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
271 |
+
" return ds_enc"
|
272 |
+
],
|
273 |
+
"outputs": [],
|
274 |
+
"execution_count": 10,
|
275 |
+
"metadata": {
|
276 |
+
"gather": {
|
277 |
+
"logged": 1706550381637
|
278 |
+
}
|
279 |
+
}
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "markdown",
|
283 |
+
"source": [
|
284 |
+
"## Evaluation metrics"
|
285 |
+
],
|
286 |
+
"metadata": {}
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"source": [
|
291 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
292 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
293 |
+
"f1 = evaluate.load(\"f1\")"
|
294 |
+
],
|
295 |
+
"outputs": [],
|
296 |
+
"execution_count": 11,
|
297 |
+
"metadata": {
|
298 |
+
"gather": {
|
299 |
+
"logged": 1706550381778
|
300 |
+
}
|
301 |
+
}
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"source": [
|
306 |
+
"def compute_metrics(eval_pred):\n",
|
307 |
+
" predictions, labels = eval_pred\n",
|
308 |
+
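"    # Convert logits to predicted class ids\n",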
" predictions = np.argmax(predictions, axis=1)\n",
|
309 |
+
" return {\n",
|
310 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
311 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
312 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
313 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
314 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
315 |
+
" 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
316 |
+
" }"
|
317 |
+
],
|
318 |
+
"outputs": [],
|
319 |
+
"execution_count": 12,
|
320 |
+
"metadata": {
|
321 |
+
"gather": {
|
322 |
+
"logged": 1706550381891
|
323 |
+
}
|
324 |
+
}
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "markdown",
|
328 |
+
"source": [
|
329 |
+
"## Training"
|
330 |
+
],
|
331 |
+
"metadata": {}
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "markdown",
|
335 |
+
"source": [
|
336 |
+
"We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
|
337 |
+
],
|
338 |
+
"metadata": {}
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "code",
|
342 |
+
"source": [
|
343 |
+
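"# Build the id → label-name map from the dataset's ClassLabel feature\n",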
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
344 |
+
],
|
345 |
+
"outputs": [],
|
346 |
+
"execution_count": 13,
|
347 |
+
"metadata": {
|
348 |
+
"gather": {
|
349 |
+
"logged": 1706550382032
|
350 |
+
}
|
351 |
+
}
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"source": [
|
356 |
+
"def train_from_model(model_ckpt: str, push: bool = False):\n",
|
357 |
+
" print(f\"Initialising training based on {model_ckpt}...\")\n",
|
358 |
+
"\n",
|
359 |
+
" print(\"Tokenising...\")\n",
|
360 |
+
" tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
361 |
+
"\n",
|
362 |
+
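"    # Drop all columns except the label; the tokenizer output replaces the raw text\n",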
" cols = dataset[\"train\"].column_names\n",
|
363 |
+
" cols.remove(\"label\")\n",
|
364 |
+
" ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
|
365 |
+
"\n",
|
366 |
+
" print(\"Loading model...\")\n",
|
367 |
+
" try:\n",
|
368 |
+
" model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
369 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
370 |
+
" id2label=label_map, \n",
|
371 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
372 |
+
" except OSError:\n",
|
373 |
+
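"        # No PyTorch checkpoint found – fall back to TensorFlow weights\n",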
" model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
374 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
375 |
+
" id2label=label_map, \n",
|
376 |
+
" label2id={v:k for k,v in label_map.items()},\n",
|
377 |
+
" from_tf=True)\n",
|
378 |
+
"\n",
|
379 |
+
"\n",
|
380 |
+
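"    # Evaluate and checkpoint once per epoch; reload the best checkpoint when training ends\n",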
" args = TrainingArguments(\n",
|
381 |
+
" output_dir=\"vaers\",\n",
|
382 |
+
" evaluation_strategy=\"epoch\",\n",
|
383 |
+
" save_strategy=\"epoch\",\n",
|
384 |
+
" learning_rate=2e-5,\n",
|
385 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
386 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
387 |
+
" num_train_epochs=EPOCHS,\n",
|
388 |
+
" weight_decay=.01,\n",
|
389 |
+
" logging_steps=1,\n",
|
390 |
+
" load_best_model_at_end=True,\n",
|
391 |
+
" run_name=f\"daedra-training\",\n",
|
392 |
+
" report_to=[\"wandb\"])\n",
|
393 |
+
"\n",
|
394 |
+
" trainer = Trainer(\n",
|
395 |
+
" model=model,\n",
|
396 |
+
" args=args,\n",
|
397 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
398 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
399 |
+
" tokenizer=tokenizer,\n",
|
400 |
+
" compute_metrics=compute_metrics)\n",
|
401 |
+
" \n",
|
402 |
+
" if SUBSAMPLING != 1.0:\n",
|
403 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
404 |
+
" else:\n",
|
405 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
406 |
+
"\n",
|
407 |
+
" wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
408 |
+
" wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
409 |
+
" \n",
|
410 |
+
" wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
|
411 |
+
"\n",
|
412 |
+
" print(\"Starting training...\")\n",
|
413 |
+
"\n",
|
414 |
+
" trainer.train()\n",
|
415 |
+
"\n",
|
416 |
+
" print(\"Training finished.\")\n",
|
417 |
+
"\n",
|
418 |
+
" if push:\n",
|
419 |
+
" variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
|
420 |
+
" tokenizer._tokenizer.save(\"tokenizer.json\")\n",
|
421 |
+
" tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
|
422 |
+
" sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
|
423 |
+
"\n",
|
424 |
+
" model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
|
425 |
+
" variant=variant,\n",
|
426 |
+
" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
|
427 |
+
],
|
428 |
+
"outputs": [],
|
429 |
+
"execution_count": 14,
|
430 |
+
"metadata": {
|
431 |
+
"jupyter": {
|
432 |
+
"outputs_hidden": false,
|
433 |
+
"source_hidden": false
|
434 |
+
},
|
435 |
+
"nteract": {
|
436 |
+
"transient": {
|
437 |
+
"deleting": false
|
438 |
+
}
|
439 |
+
},
|
440 |
+
"gather": {
|
441 |
+
"logged": 1706550382160
|
442 |
+
}
|
443 |
+
}
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"cell_type": "code",
|
447 |
+
"source": [
|
448 |
+
"\n",
|
449 |
+
"base_models = [\n",
|
450 |
+
" \"bert-base-uncased\",\n",
|
451 |
+
" \"distilbert-base-uncased\",\n",
|
452 |
+
"]"
|
453 |
+
],
|
454 |
+
"outputs": [],
|
455 |
+
"execution_count": 15,
|
456 |
+
"metadata": {
|
457 |
+
"gather": {
|
458 |
+
"logged": 1706550382318
|
459 |
+
}
|
460 |
+
}
|
461 |
+
},
|
462 |
+
{
|
463 |
+
"cell_type": "code",
|
464 |
+
"source": [
|
465 |
+
"BATCH_SIZE=1\n",
|
466 |
+
"\n",
|
467 |
+
"train_from_model(\"biobert/Bio_ClinicalBERT/\")"
|
468 |
+
],
|
469 |
+
"outputs": [
|
470 |
+
{
|
471 |
+
"output_type": "stream",
|
472 |
+
"name": "stdout",
|
473 |
+
"text": "Initialising training based on biobert/Bio_ClinicalBERT/...\nTokenising...\nLoading model...\n"
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"output_type": "stream",
|
477 |
+
"name": "stderr",
|
478 |
+
"text": "Map: 100%|██████████| 2722/2722 [00:01<00:00, 2195.12 examples/s]\nAll TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nAll the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"output_type": "display_data",
|
482 |
+
"data": {
|
483 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
484 |
+
"text/html": "Finishing last run (ID:sg022tqh) before initializing another..."
|
485 |
+
},
|
486 |
+
"metadata": {}
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"output_type": "display_data",
|
490 |
+
"data": {
|
491 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
492 |
+
"text/html": " View run <strong style=\"color:#cdcd00\">daedra_0.01-biobert/Bio_ClinicalBERT/</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
493 |
+
},
|
494 |
+
"metadata": {}
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"output_type": "display_data",
|
498 |
+
"data": {
|
499 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
500 |
+
"text/html": "Find logs at: <code>./wandb/run-20240129_174816-sg022tqh/logs</code>"
|
501 |
+
},
|
502 |
+
"metadata": {}
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"output_type": "display_data",
|
506 |
+
"data": {
|
507 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
508 |
+
"text/html": "Successfully finished last run (ID:sg022tqh). Initializing new run:<br/>"
|
509 |
+
},
|
510 |
+
"metadata": {}
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"output_type": "display_data",
|
514 |
+
"data": {
|
515 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
516 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
517 |
+
},
|
518 |
+
"metadata": {}
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"output_type": "display_data",
|
522 |
+
"data": {
|
523 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
524 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_174936-kilkkg1j</code>"
|
525 |
+
},
|
526 |
+
"metadata": {}
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"output_type": "display_data",
|
530 |
+
"data": {
|
531 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
532 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">daedra_0.01-biobert/Bio_ClinicalBERT/</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
533 |
+
},
|
534 |
+
"metadata": {}
|
535 |
+
},
|
536 |
+
{
|
537 |
+
"output_type": "display_data",
|
538 |
+
"data": {
|
539 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
540 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
541 |
+
},
|
542 |
+
"metadata": {}
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"output_type": "display_data",
|
546 |
+
"data": {
|
547 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
548 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j</a>"
|
549 |
+
},
|
550 |
+
"metadata": {}
|
551 |
+
},
|
552 |
+
{
|
553 |
+
"output_type": "stream",
|
554 |
+
"name": "stdout",
|
555 |
+
"text": "Starting training...\n"
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"output_type": "stream",
|
559 |
+
"name": "stderr",
|
560 |
+
"text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"output_type": "display_data",
|
564 |
+
"data": {
|
565 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
566 |
+
"text/html": "\n <div>\n \n <progress value='1496' max='15880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1496/15880 07:43 < 1:14:19, 3.23 it/s, Epoch 0.47/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
|
567 |
+
},
|
568 |
+
"metadata": {}
|
569 |
+
}
|
570 |
+
],
|
571 |
+
"execution_count": 21,
|
572 |
+
"metadata": {
|
573 |
+
"gather": {
|
574 |
+
"logged": 1706551053473
|
575 |
+
}
|
576 |
+
}
|
577 |
+
},
|
578 |
+
{
|
579 |
+
"cell_type": "code",
|
580 |
+
"source": [],
|
581 |
+
"outputs": [],
|
582 |
+
"execution_count": null,
|
583 |
+
"metadata": {
|
584 |
+
"jupyter": {
|
585 |
+
"source_hidden": false,
|
586 |
+
"outputs_hidden": false
|
587 |
+
},
|
588 |
+
"nteract": {
|
589 |
+
"transient": {
|
590 |
+
"deleting": false
|
591 |
+
}
|
592 |
+
}
|
593 |
+
}
|
594 |
+
}
|
595 |
+
],
|
596 |
+
"metadata": {
|
597 |
+
"datalore": {
|
598 |
+
"base_environment": "default",
|
599 |
+
"computation_mode": "JUPYTER",
|
600 |
+
"package_manager": "pip",
|
601 |
+
"packages": [
|
602 |
+
{
|
603 |
+
"name": "datasets",
|
604 |
+
"source": "PIP",
|
605 |
+
"version": "2.16.1"
|
606 |
+
},
|
607 |
+
{
|
608 |
+
"name": "torch",
|
609 |
+
"source": "PIP",
|
610 |
+
"version": "2.1.2"
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"name": "accelerate",
|
614 |
+
"source": "PIP",
|
615 |
+
"version": "0.26.1"
|
616 |
+
}
|
617 |
+
],
|
618 |
+
"report_row_ids": [
|
619 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
620 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
621 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
622 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
623 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
624 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
625 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
626 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
627 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
628 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
629 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
630 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
631 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
632 |
+
],
|
633 |
+
"version": 3
|
634 |
+
},
|
635 |
+
"kernel_info": {
|
636 |
+
"name": "python38-azureml-pt-tf"
|
637 |
+
},
|
638 |
+
"kernelspec": {
|
639 |
+
"display_name": "azureml_py38_PT_TF",
|
640 |
+
"language": "python",
|
641 |
+
"name": "python3"
|
642 |
+
},
|
643 |
+
"language_info": {
|
644 |
+
"name": "python",
|
645 |
+
"version": "3.8.5",
|
646 |
+
"mimetype": "text/x-python",
|
647 |
+
"codemirror_mode": {
|
648 |
+
"name": "ipython",
|
649 |
+
"version": 3
|
650 |
+
},
|
651 |
+
"pygments_lexer": "ipython3",
|
652 |
+
"nbconvert_exporter": "python",
|
653 |
+
"file_extension": ".py"
|
654 |
+
},
|
655 |
+
"microsoft": {
|
656 |
+
"host": {
|
657 |
+
"AzureML": {
|
658 |
+
"notebookHasBeenCompleted": true
|
659 |
+
}
|
660 |
+
},
|
661 |
+
"ms_spell_check": {
|
662 |
+
"ms_spell_check_language": "en"
|
663 |
+
}
|
664 |
+
},
|
665 |
+
"nteract": {
|
666 |
+
"version": "nteract-front-end@1.0.0"
|
667 |
+
}
|
668 |
+
},
|
669 |
+
"nbformat": 4,
|
670 |
+
"nbformat_minor": 4
|
671 |
+
}
|
notebooks/.ipynb_aml_checkpoints/microsample_model_comparison-checkpoint2024-0-31-14-6-22Z.ipynb
ADDED
File without changes
|
notebooks/DAEDRA-Copy1.ipynb
ADDED
@@ -0,0 +1,1634 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"metadata": {},
|
6 |
+
"source": [
|
7 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
8 |
+
"\n",
|
9 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
10 |
+
]
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"execution_count": 1,
|
15 |
+
"metadata": {
|
16 |
+
"nteract": {
|
17 |
+
"transient": {
|
18 |
+
"deleting": false
|
19 |
+
}
|
20 |
+
},
|
21 |
+
"tags": []
|
22 |
+
},
|
23 |
+
"outputs": [
|
24 |
+
{
|
25 |
+
"name": "stdout",
|
26 |
+
"output_type": "stream",
|
27 |
+
"text": [
|
28 |
+
"Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\n",
|
29 |
+
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\n",
|
30 |
+
"Requirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\n",
|
31 |
+
"Requirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\n",
|
32 |
+
"Requirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\n",
|
33 |
+
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\n",
|
34 |
+
"Requirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\n",
|
35 |
+
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\n",
|
36 |
+
"Requirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\n",
|
37 |
+
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\n",
|
38 |
+
"Requirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\n",
|
39 |
+
"Requirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\n",
|
40 |
+
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\n",
|
41 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\n",
|
42 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\n",
|
43 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\n",
|
44 |
+
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\n",
|
45 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
46 |
+
]
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"source": [
|
50 |
+
"# %pip install accelerate -U"
|
51 |
+
]
|
52 |
+
},
|
53 |
+
{
|
54 |
+
"cell_type": "code",
|
55 |
+
"execution_count": 2,
|
56 |
+
"metadata": {
|
57 |
+
"collapsed": true,
|
58 |
+
"jupyter": {
|
59 |
+
"outputs_hidden": true,
|
60 |
+
"source_hidden": false
|
61 |
+
},
|
62 |
+
"nteract": {
|
63 |
+
"transient": {
|
64 |
+
"deleting": false
|
65 |
+
}
|
66 |
+
}
|
67 |
+
},
|
68 |
+
"outputs": [
|
69 |
+
{
|
70 |
+
"name": "stdout",
|
71 |
+
"output_type": "stream",
|
72 |
+
"text": [
|
73 |
+
"Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\n",
|
74 |
+
"Requirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\n",
|
75 |
+
"Requirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\n",
|
76 |
+
"Requirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\n",
|
77 |
+
"Requirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\n",
|
78 |
+
"Requirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\n",
|
79 |
+
"Requirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\n",
|
80 |
+
"Requirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\n",
|
81 |
+
"Requirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\n",
|
82 |
+
"Requirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\n",
|
83 |
+
"Requirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\n",
|
84 |
+
"Requirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\n",
|
85 |
+
"Requirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\n",
|
86 |
+
"Requirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\n",
|
87 |
+
"Requirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\n",
|
88 |
+
"Requirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\n",
|
89 |
+
"Requirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\n",
|
90 |
+
"Requirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\n",
|
91 |
+
"Requirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\n",
|
92 |
+
"Requirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\n",
|
93 |
+
"Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\n",
|
94 |
+
"Requirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\n",
|
95 |
+
"Requirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\n",
|
96 |
+
"Requirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\n",
|
97 |
+
"Requirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\n",
|
98 |
+
"Requirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\n",
|
99 |
+
"Requirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\n",
|
100 |
+
"Requirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\n",
|
101 |
+
"Requirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\n",
|
102 |
+
"Requirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\n",
|
103 |
+
"Requirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\n",
|
104 |
+
"Requirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\n",
|
105 |
+
"Requirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\n",
|
106 |
+
"Requirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\n",
|
107 |
+
"Requirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\n",
|
108 |
+
"Requirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\n",
|
109 |
+
"Requirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\n",
|
110 |
+
"Requirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\n",
|
111 |
+
"Requirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\n",
|
112 |
+
"Requirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\n",
|
113 |
+
"Requirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n",
|
114 |
+
"Requirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\n",
|
115 |
+
"Requirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\n",
|
116 |
+
"Requirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\n",
|
117 |
+
"Requirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\n",
|
118 |
+
"Requirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\n",
|
119 |
+
"Requirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\n",
|
120 |
+
"Requirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\n",
|
121 |
+
"Requirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\n",
|
122 |
+
"Requirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\n",
|
123 |
+
"Requirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\n",
|
124 |
+
"Requirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\n",
|
125 |
+
"Requirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\n",
|
126 |
+
"Requirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\n",
|
127 |
+
"Requirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\n",
|
128 |
+
"Requirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\n",
|
129 |
+
"Requirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\n",
|
130 |
+
"Requirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\n",
|
131 |
+
"Requirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\n",
|
132 |
+
"Requirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\n",
|
133 |
+
"Requirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\n",
|
134 |
+
"Requirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\n",
|
135 |
+
"Requirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\n",
|
136 |
+
"Requirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\n",
|
137 |
+
"Requirement already satisfied: python-dateutil>=2.8.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2.8.2)\n",
|
138 |
+
"Requirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
139 |
+
"Requirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\n",
|
140 |
+
"Requirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\n",
|
141 |
+
"Requirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\n",
|
142 |
+
"Requirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\n",
|
143 |
+
"Requirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\n",
|
144 |
+
"Requirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\n",
|
145 |
+
"Requirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\n",
|
146 |
+
"Requirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\n",
|
147 |
+
"Requirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\n",
|
148 |
+
"Requirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\n",
|
149 |
+
"Note: you may need to restart the kernel to use updated packages.\n"
|
150 |
+
]
|
151 |
+
}
|
152 |
+
],
|
153 |
+
"source": [
|
154 |
+
"# %pip install transformers datasets shap watermark wandb"
|
155 |
+
]
|
156 |
+
},
|
157 |
+
{
|
158 |
+
"cell_type": "code",
|
159 |
+
"execution_count": 66,
|
160 |
+
"metadata": {
|
161 |
+
"datalore": {
|
162 |
+
"hide_input_from_viewers": false,
|
163 |
+
"hide_output_from_viewers": false,
|
164 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
165 |
+
"report_properties": {
|
166 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
167 |
+
},
|
168 |
+
"type": "CODE"
|
169 |
+
},
|
170 |
+
"gather": {
|
171 |
+
"logged": 1706449625034
|
172 |
+
},
|
173 |
+
"tags": []
|
174 |
+
},
|
175 |
+
"outputs": [
|
176 |
+
{
|
177 |
+
"name": "stdout",
|
178 |
+
"output_type": "stream",
|
179 |
+
"text": [
|
180 |
+
"The watermark extension is already loaded. To reload it, use:\n",
|
181 |
+
" %reload_ext watermark\n"
|
182 |
+
]
|
183 |
+
}
|
184 |
+
],
|
185 |
+
"source": [
|
186 |
+
"import pandas as pd\n",
|
187 |
+
"import numpy as np\n",
|
188 |
+
"import torch\n",
|
189 |
+
"import os\n",
|
190 |
+
"from typing import List\n",
|
191 |
+
"from sklearn.metrics import f1_score, accuracy_score, classification_report\n",
|
192 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, pipeline\n",
|
193 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
194 |
+
"from pyarrow import Table\n",
|
195 |
+
"import shap\n",
|
196 |
+
"import wandb\n",
|
197 |
+
"\n",
|
198 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
199 |
+
"\n",
|
200 |
+
"%load_ext watermark"
|
201 |
+
]
|
202 |
+
},
|
203 |
+
{
|
204 |
+
"cell_type": "code",
|
205 |
+
"execution_count": 43,
|
206 |
+
"metadata": {
|
207 |
+
"collapsed": false,
|
208 |
+
"gather": {
|
209 |
+
"logged": 1706449721319
|
210 |
+
},
|
211 |
+
"jupyter": {
|
212 |
+
"outputs_hidden": false
|
213 |
+
}
|
214 |
+
},
|
215 |
+
"outputs": [],
|
216 |
+
"source": [
|
217 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
218 |
+
"\n",
|
219 |
+
"SEED: int = 42\n",
|
220 |
+
"\n",
|
221 |
+
"BATCH_SIZE: int = 32\n",
|
222 |
+
"EPOCHS: int = 3\n",
|
223 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
224 |
+
"\n",
|
225 |
+
"CLASS_NAMES: List[str] = [\"DIED\",\n",
|
226 |
+
" \"ER_VISIT\",\n",
|
227 |
+
" \"HOSPITAL\",\n",
|
228 |
+
" \"OFC_VISIT\",\n",
|
229 |
+
" #\"X_STAY\", # pruned\n",
|
230 |
+
" #\"DISABLE\", # pruned\n",
|
231 |
+
" #\"D_PRESENTED\" # pruned\n",
|
232 |
+
" ]\n",
|
233 |
+
"\n",
|
234 |
+
"\n",
|
235 |
+
"\n",
|
236 |
+
"\n",
|
237 |
+
"# WandB configuration\n",
|
238 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA model training\" # name your W&B project\n",
|
239 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
240 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
241 |
+
]
|
242 |
+
},
|
243 |
+
{
|
244 |
+
"cell_type": "code",
|
245 |
+
"execution_count": 44,
|
246 |
+
"metadata": {
|
247 |
+
"collapsed": false,
|
248 |
+
"jupyter": {
|
249 |
+
"outputs_hidden": false
|
250 |
+
}
|
251 |
+
},
|
252 |
+
"outputs": [
|
253 |
+
{
|
254 |
+
"name": "stdout",
|
255 |
+
"output_type": "stream",
|
256 |
+
"text": [
|
257 |
+
"shap : 0.44.1\n",
|
258 |
+
"torch : 1.12.0\n",
|
259 |
+
"logging: 0.5.1.2\n",
|
260 |
+
"numpy : 1.23.5\n",
|
261 |
+
"pandas : 2.0.2\n",
|
262 |
+
"re : 2.2.1\n",
|
263 |
+
"\n"
|
264 |
+
]
|
265 |
+
}
|
266 |
+
],
|
267 |
+
"source": [
|
268 |
+
"%watermark --iversion"
|
269 |
+
]
|
270 |
+
},
|
271 |
+
{
|
272 |
+
"cell_type": "code",
|
273 |
+
"execution_count": 45,
|
274 |
+
"metadata": {
|
275 |
+
"datalore": {
|
276 |
+
"hide_input_from_viewers": true,
|
277 |
+
"hide_output_from_viewers": true,
|
278 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
279 |
+
"type": "CODE"
|
280 |
+
}
|
281 |
+
},
|
282 |
+
"outputs": [
|
283 |
+
{
|
284 |
+
"name": "stdout",
|
285 |
+
"output_type": "stream",
|
286 |
+
"text": [
|
287 |
+
"Sun Jan 28 13:54:22 2024 \n",
|
288 |
+
"+---------------------------------------------------------------------------------------+\n",
|
289 |
+
"| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\n",
|
290 |
+
"|-----------------------------------------+----------------------+----------------------+\n",
|
291 |
+
"| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\n",
|
292 |
+
"| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\n",
|
293 |
+
"| | | MIG M. |\n",
|
294 |
+
"|=========================================+======================+======================|\n",
|
295 |
+
"| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\n",
|
296 |
+
"| N/A 30C P0 38W / 250W | 12830MiB / 16384MiB | 0% Default |\n",
|
297 |
+
"| | | N/A |\n",
|
298 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
299 |
+
"| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\n",
|
300 |
+
"| N/A 30C P0 38W / 250W | 11960MiB / 16384MiB | 0% Default |\n",
|
301 |
+
"| | | N/A |\n",
|
302 |
+
"+-----------------------------------------+----------------------+----------------------+\n",
|
303 |
+
" \n",
|
304 |
+
"+---------------------------------------------------------------------------------------+\n",
|
305 |
+
"| Processes: |\n",
|
306 |
+
"| GPU GI CI PID Type Process name GPU Memory |\n",
|
307 |
+
"| ID ID Usage |\n",
|
308 |
+
"|=======================================================================================|\n",
|
309 |
+
"| 0 N/A N/A 11781 C .../envs/azureml_py38_PT_TF/bin/python 12826MiB |\n",
|
310 |
+
"| 1 N/A N/A 11781 C .../envs/azureml_py38_PT_TF/bin/python 11956MiB |\n",
|
311 |
+
"+---------------------------------------------------------------------------------------+\n"
|
312 |
+
]
|
313 |
+
}
|
314 |
+
],
|
315 |
+
"source": [
|
316 |
+
"!nvidia-smi"
|
317 |
+
]
|
318 |
+
},
|
319 |
+
{
|
320 |
+
"cell_type": "markdown",
|
321 |
+
"metadata": {
|
322 |
+
"datalore": {
|
323 |
+
"hide_input_from_viewers": false,
|
324 |
+
"hide_output_from_viewers": false,
|
325 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
326 |
+
"report_properties": {
|
327 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
328 |
+
},
|
329 |
+
"type": "MD"
|
330 |
+
}
|
331 |
+
},
|
332 |
+
"source": [
|
333 |
+
"## Loading the data set"
|
334 |
+
]
|
335 |
+
},
|
336 |
+
{
|
337 |
+
"cell_type": "code",
|
338 |
+
"execution_count": 46,
|
339 |
+
"metadata": {
|
340 |
+
"collapsed": false,
|
341 |
+
"gather": {
|
342 |
+
"logged": 1706449040507
|
343 |
+
},
|
344 |
+
"jupyter": {
|
345 |
+
"outputs_hidden": false
|
346 |
+
}
|
347 |
+
},
|
348 |
+
"outputs": [],
|
349 |
+
"source": [
|
350 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
351 |
+
]
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"execution_count": 47,
|
356 |
+
"metadata": {
|
357 |
+
"collapsed": false,
|
358 |
+
"gather": {
|
359 |
+
"logged": 1706449044205
|
360 |
+
},
|
361 |
+
"jupyter": {
|
362 |
+
"outputs_hidden": false,
|
363 |
+
"source_hidden": false
|
364 |
+
},
|
365 |
+
"nteract": {
|
366 |
+
"transient": {
|
367 |
+
"deleting": false
|
368 |
+
}
|
369 |
+
}
|
370 |
+
},
|
371 |
+
"outputs": [
|
372 |
+
{
|
373 |
+
"data": {
|
374 |
+
"text/plain": [
|
375 |
+
"DatasetDict({\n",
|
376 |
+
" train: Dataset({\n",
|
377 |
+
" features: ['id', 'text', 'labels'],\n",
|
378 |
+
" num_rows: 1270444\n",
|
379 |
+
" })\n",
|
380 |
+
" test: Dataset({\n",
|
381 |
+
" features: ['id', 'text', 'labels'],\n",
|
382 |
+
" num_rows: 272238\n",
|
383 |
+
" })\n",
|
384 |
+
" val: Dataset({\n",
|
385 |
+
" features: ['id', 'text', 'labels'],\n",
|
386 |
+
" num_rows: 272238\n",
|
387 |
+
" })\n",
|
388 |
+
"})"
|
389 |
+
]
|
390 |
+
},
|
391 |
+
"execution_count": 47,
|
392 |
+
"metadata": {},
|
393 |
+
"output_type": "execute_result"
|
394 |
+
}
|
395 |
+
],
|
396 |
+
"source": [
|
397 |
+
"dataset"
|
398 |
+
]
|
399 |
+
},
|
400 |
+
{
|
401 |
+
"cell_type": "code",
|
402 |
+
"execution_count": 70,
|
403 |
+
"metadata": {},
|
404 |
+
"outputs": [],
|
405 |
+
"source": [
|
406 |
+
"SUBSAMPLING: float = 0.1"
|
407 |
+
]
|
408 |
+
},
|
409 |
+
{
|
410 |
+
"cell_type": "code",
|
411 |
+
"execution_count": 48,
|
412 |
+
"metadata": {
|
413 |
+
"collapsed": false,
|
414 |
+
"gather": {
|
415 |
+
"logged": 1706449378281
|
416 |
+
},
|
417 |
+
"jupyter": {
|
418 |
+
"outputs_hidden": false,
|
419 |
+
"source_hidden": false
|
420 |
+
},
|
421 |
+
"nteract": {
|
422 |
+
"transient": {
|
423 |
+
"deleting": false
|
424 |
+
}
|
425 |
+
}
|
426 |
+
},
|
427 |
+
"outputs": [],
|
428 |
+
"source": [
|
429 |
+
"def minisample(ds: DatasetDict, fraction: float) -> DatasetDict:\n",
|
430 |
+
" res = DatasetDict()\n",
|
431 |
+
"\n",
|
432 |
+
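"    # Draw a random fraction of each split; seeded so subsamples are reproducible\n",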
" res[\"train\"] = Dataset.from_dict(ds[\"train\"].shuffle()[:round(len(ds[\"train\"]) * fraction)])\n",
|
433 |
+
" res[\"test\"] = Dataset.from_dict(ds[\"test\"].shuffle()[:round(len(ds[\"test\"]) * fraction)])\n",
|
434 |
+
" res[\"val\"] = Dataset.from_dict(ds[\"val\"].shuffle()[:round(len(ds[\"val\"]) * fraction)])\n",
|
435 |
+
" \n",
|
436 |
+
" return res"
|
437 |
+
]
|
438 |
+
},
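`Dataset.from_dict` over a shuffled slice materialises each split as a Python dict before rebuilding it. A leaner equivalent (a hedged sketch, not the notebook's code) stays in Arrow by combining `shuffle()` with `select()`:

```python
from datasets import Dataset, DatasetDict

def minisample_select(ds: DatasetDict, fraction: float) -> DatasetDict:
    # select() keeps the data in Arrow instead of round-tripping through a dict
    return DatasetDict({
        split: ds[split].shuffle().select(range(round(len(ds[split]) * fraction)))
        for split in ("train", "test", "val")
    })
```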
|
439 |
+
{
|
440 |
+
"cell_type": "code",
|
441 |
+
"execution_count": 49,
|
442 |
+
"metadata": {
|
443 |
+
"collapsed": false,
|
444 |
+
"gather": {
|
445 |
+
"logged": 1706449384162
|
446 |
+
},
|
447 |
+
"jupyter": {
|
448 |
+
"outputs_hidden": false,
|
449 |
+
"source_hidden": false
|
450 |
+
},
|
451 |
+
"nteract": {
|
452 |
+
"transient": {
|
453 |
+
"deleting": false
|
454 |
+
}
|
455 |
+
}
|
456 |
+
},
|
457 |
+
"outputs": [],
|
458 |
+
"source": [
|
459 |
+
"dataset = minisample(dataset, SUBSAMPLING)"
|
460 |
+
]
|
461 |
+
},
|
462 |
+
{
|
463 |
+
"cell_type": "code",
|
464 |
+
"execution_count": 50,
|
465 |
+
"metadata": {
|
466 |
+
"collapsed": false,
|
467 |
+
"gather": {
|
468 |
+
"logged": 1706449387981
|
469 |
+
},
|
470 |
+
"jupyter": {
|
471 |
+
"outputs_hidden": false,
|
472 |
+
"source_hidden": false
|
473 |
+
},
|
474 |
+
"nteract": {
|
475 |
+
"transient": {
|
476 |
+
"deleting": false
|
477 |
+
}
|
478 |
+
}
|
479 |
+
},
|
480 |
+
"outputs": [
|
481 |
+
{
|
482 |
+
"data": {
|
483 |
+
"text/plain": [
|
484 |
+
"DatasetDict({\n",
|
485 |
+
" train: Dataset({\n",
|
486 |
+
" features: ['id', 'text', 'labels'],\n",
|
487 |
+
" num_rows: 127044\n",
|
488 |
+
" })\n",
|
489 |
+
" test: Dataset({\n",
|
490 |
+
" features: ['id', 'text', 'labels'],\n",
|
491 |
+
" num_rows: 27224\n",
|
492 |
+
" })\n",
|
493 |
+
" val: Dataset({\n",
|
494 |
+
" features: ['id', 'text', 'labels'],\n",
|
495 |
+
" num_rows: 27224\n",
|
496 |
+
" })\n",
|
497 |
+
"})"
|
498 |
+
]
|
499 |
+
},
|
500 |
+
"execution_count": 50,
|
501 |
+
"metadata": {},
|
502 |
+
"output_type": "execute_result"
|
503 |
+
}
|
504 |
+
],
|
505 |
+
"source": [
|
506 |
+
"dataset"
|
507 |
+
]
|
508 |
+
},
|
509 |
+
{
|
510 |
+
"cell_type": "markdown",
|
511 |
+
"metadata": {
|
512 |
+
"nteract": {
|
513 |
+
"transient": {
|
514 |
+
"deleting": false
|
515 |
+
}
|
516 |
+
}
|
517 |
+
},
|
518 |
+
"source": [
|
519 |
+
"We prune things down to the first four keys: `DIED`, `ER_VISIT`, `HOSPITAL`, `OFC_VISIT`."
|
520 |
+
]
|
521 |
+
},
|
522 |
+
{
|
523 |
+
"cell_type": "code",
|
524 |
+
"execution_count": 51,
|
525 |
+
"metadata": {
|
526 |
+
"collapsed": false,
|
527 |
+
"gather": {
|
528 |
+
"logged": 1706449443055
|
529 |
+
},
|
530 |
+
"jupyter": {
|
531 |
+
"outputs_hidden": false,
|
532 |
+
"source_hidden": false
|
533 |
+
},
|
534 |
+
"nteract": {
|
535 |
+
"transient": {
|
536 |
+
"deleting": false
|
537 |
+
}
|
538 |
+
}
|
539 |
+
},
|
540 |
+
"outputs": [],
|
541 |
+
"source": [
|
542 |
+
"ds = DatasetDict()\n",
|
543 |
+
"\n",
|
544 |
+
"for i in [\"test\", \"train\", \"val\"]:\n",
|
545 |
+
" tab = Table.from_arrays([dataset[i][\"id\"], dataset[i][\"text\"], [i[:4] for i in dataset[i][\"labels\"]]], names=[\"id\", \"text\", \"labels\"])\n",
|
546 |
+
" ds[i] = Dataset(tab)\n",
|
547 |
+
"\n",
|
548 |
+
"dataset = ds"
|
549 |
+
]
|
550 |
+
},
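The comprehension truncates each multi-hot label vector to its first four positions, which correspond to `DIED`, `ER_VISIT`, `HOSPITAL` and `OFC_VISIT`. An equivalent sketch using `datasets.map` (hypothetical helper, not the notebook's code) avoids rebuilding the pyarrow `Table` by hand:

```python
# Equivalent label pruning via datasets.map (assumes `labels` is a multi-hot
# list whose first four entries are DIED, ER_VISIT, HOSPITAL and OFC_VISIT).
def keep_first_four_labels(batch):
    batch["labels"] = [label_vector[:4] for label_vector in batch["labels"]]
    return batch

# dataset = dataset.map(keep_first_four_labels, batched=True)
```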
|
551 |
+
{
|
552 |
+
"cell_type": "markdown",
|
553 |
+
"metadata": {},
|
554 |
+
"source": [
|
555 |
+
"### Tokenisation and encoding"
|
556 |
+
]
|
557 |
+
},
|
558 |
+
{
|
559 |
+
"cell_type": "code",
|
560 |
+
"execution_count": 52,
|
561 |
+
"metadata": {
|
562 |
+
"datalore": {
|
563 |
+
"hide_input_from_viewers": true,
|
564 |
+
"hide_output_from_viewers": true,
|
565 |
+
"node_id": "I7n646PIscsUZRoHu6m7zm",
|
566 |
+
"type": "CODE"
|
567 |
+
},
|
568 |
+
"gather": {
|
569 |
+
"logged": 1706449638377
|
570 |
+
}
|
571 |
+
},
|
572 |
+
"outputs": [],
|
573 |
+
"source": [
|
574 |
+
"tokenizer = AutoTokenizer.from_pretrained(model_ckpt)"
|
575 |
+
]
|
576 |
+
},
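`model_ckpt` is defined earlier in the notebook; judging by the checkpoint warning emitted when the model is loaded below, it points at `distilbert-base-uncased`. A self-contained sanity check of what the tokenizer produces (illustrative only):

```python
from transformers import AutoTokenizer

# Assumes model_ckpt == "distilbert-base-uncased" (see the model-loading
# warning later in this notebook).
tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")
enc = tok("Patient developed fever and chills after vaccination.", truncation=True)
print(list(enc.keys()))       # ['input_ids', 'attention_mask']
print(len(enc["input_ids"]))  # token count, capped at the model's 512 maximum
```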
|
577 |
+
{
|
578 |
+
"cell_type": "code",
|
579 |
+
"execution_count": 53,
|
580 |
+
"metadata": {
|
581 |
+
"datalore": {
|
582 |
+
"hide_input_from_viewers": true,
|
583 |
+
"hide_output_from_viewers": true,
|
584 |
+
"node_id": "QBLOSI0yVIslV7v7qX9ZC3",
|
585 |
+
"type": "CODE"
|
586 |
+
},
|
587 |
+
"gather": {
|
588 |
+
"logged": 1706449642580
|
589 |
+
}
|
590 |
+
},
|
591 |
+
"outputs": [],
|
592 |
+
"source": [
|
593 |
+
"def tokenize_and_encode(examples):\n",
|
594 |
+
" return tokenizer(examples[\"text\"], truncation=True)"
|
595 |
+
]
|
596 |
+
},
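Note that `tokenize_and_encode` truncates but does not pad: because a `tokenizer` is passed to the `Trainer` below, batches are padded dynamically by the default `DataCollatorWithPadding`. A hedged sketch of that collator used standalone:

```python
from transformers import AutoTokenizer, DataCollatorWithPadding

tok = AutoTokenizer.from_pretrained("distilbert-base-uncased")  # assumed checkpoint
collator = DataCollatorWithPadding(tokenizer=tok)

features = [tok(t, truncation=True) for t in ["short report", "a somewhat longer report text"]]
batch = collator(features)
print(batch["input_ids"].shape)  # both sequences padded to the longer one
```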
|
597 |
+
{
|
598 |
+
"cell_type": "code",
|
599 |
+
"execution_count": 54,
|
600 |
+
"metadata": {
|
601 |
+
"datalore": {
|
602 |
+
"hide_input_from_viewers": true,
|
603 |
+
"hide_output_from_viewers": true,
|
604 |
+
"node_id": "slHeNysZOX9uWS9PB7jFDb",
|
605 |
+
"type": "CODE"
|
606 |
+
},
|
607 |
+
"gather": {
|
608 |
+
"logged": 1706449721161
|
609 |
+
}
|
610 |
+
},
|
611 |
+
"outputs": [
|
612 |
+
{
|
613 |
+
"name": "stderr",
|
614 |
+
"output_type": "stream",
|
615 |
+
"text": [
|
616 |
+
"Map: 100%|██████████| 27224/27224 [00:11<00:00, 2347.91 examples/s]\n",
|
617 |
+
"Map: 100%|██████████| 127044/127044 [00:52<00:00, 2417.41 examples/s]\n",
|
618 |
+
"Map: 100%|██████████| 27224/27224 [00:11<00:00, 2376.02 examples/s]\n"
|
619 |
+
]
|
620 |
+
}
|
621 |
+
],
|
622 |
+
"source": [
|
623 |
+
"cols = dataset[\"train\"].column_names\n",
|
624 |
+
"cols.remove(\"labels\")\n",
|
625 |
+
"ds_enc = dataset.map(tokenize_and_encode, batched=True, remove_columns=cols)"
|
626 |
+
]
|
627 |
+
},
|
628 |
+
{
|
629 |
+
"cell_type": "markdown",
|
630 |
+
"metadata": {},
|
631 |
+
"source": [
|
632 |
+
"### Training"
|
633 |
+
]
|
634 |
+
},
|
635 |
+
{
|
636 |
+
"cell_type": "code",
|
637 |
+
"execution_count": 55,
|
638 |
+
"metadata": {
|
639 |
+
"datalore": {
|
640 |
+
"hide_input_from_viewers": true,
|
641 |
+
"hide_output_from_viewers": true,
|
642 |
+
"node_id": "itXWkbDw9sqbkMuDP84QoT",
|
643 |
+
"type": "CODE"
|
644 |
+
},
|
645 |
+
"gather": {
|
646 |
+
"logged": 1706449743072
|
647 |
+
}
|
648 |
+
},
|
649 |
+
"outputs": [],
|
650 |
+
"source": [
|
651 |
+
"class MultiLabelTrainer(Trainer):\n",
|
652 |
+
" def compute_loss(self, model, inputs, return_outputs=False):\n",
|
653 |
+
" labels = inputs.pop(\"labels\")\n",
|
654 |
+
" outputs = model(**inputs)\n",
|
655 |
+
" logits = outputs.logits\n",
|
656 |
+
" loss_fct = torch.nn.BCEWithLogitsLoss()\n",
|
657 |
+
" loss = loss_fct(logits.view(-1, self.model.config.num_labels),\n",
|
658 |
+
" labels.float().view(-1, self.model.config.num_labels))\n",
|
659 |
+
" return (loss, outputs) if return_outputs else loss"
|
660 |
+
]
|
661 |
+
},
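`BCEWithLogitsLoss` applies an element-wise sigmoid, so each outcome label is treated as an independent binary decision rather than as one of a set of mutually exclusive classes. A minimal standalone check of the loss used above:

```python
import torch

# One report, four labels (multi-hot): BCEWithLogitsLoss scores each label
# independently, which is what makes this a multi-label setup.
logits = torch.tensor([[2.0, -1.0, 0.5, -3.0]])
labels = torch.tensor([[1.0, 0.0, 1.0, 0.0]])
print(torch.nn.BCEWithLogitsLoss()(logits, labels).item())
```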
|
662 |
+
{
|
663 |
+
"cell_type": "code",
|
664 |
+
"execution_count": 56,
|
665 |
+
"metadata": {
|
666 |
+
"datalore": {
|
667 |
+
"hide_input_from_viewers": true,
|
668 |
+
"hide_output_from_viewers": true,
|
669 |
+
"node_id": "ZQU7aW6TV45VmhHOQRzcnF",
|
670 |
+
"type": "CODE"
|
671 |
+
},
|
672 |
+
"gather": {
|
673 |
+
"logged": 1706449761205
|
674 |
+
}
|
675 |
+
},
|
676 |
+
"outputs": [
|
677 |
+
{
|
678 |
+
"name": "stderr",
|
679 |
+
"output_type": "stream",
|
680 |
+
"text": [
|
681 |
+
"Some weights of DistilBertForSequenceClassification were not initialized from the model checkpoint at distilbert-base-uncased and are newly initialized: ['classifier.bias', 'classifier.weight', 'pre_classifier.bias', 'pre_classifier.weight']\n",
|
682 |
+
"You should probably TRAIN this model on a down-stream task to be able to use it for predictions and inference.\n"
|
683 |
+
]
|
684 |
+
}
|
685 |
+
],
|
686 |
+
"source": [
|
687 |
+
"model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, num_labels=len(CLASS_NAMES)).to(\"cuda\")"
|
688 |
+
]
|
689 |
+
},
|
690 |
+
{
|
691 |
+
"cell_type": "code",
|
692 |
+
"execution_count": 57,
|
693 |
+
"metadata": {
|
694 |
+
"datalore": {
|
695 |
+
"hide_input_from_viewers": true,
|
696 |
+
"hide_output_from_viewers": true,
|
697 |
+
"node_id": "swhgyyyxoGL8HjnXJtMuSW",
|
698 |
+
"type": "CODE"
|
699 |
+
},
|
700 |
+
"gather": {
|
701 |
+
"logged": 1706449761541
|
702 |
+
}
|
703 |
+
},
|
704 |
+
"outputs": [],
|
705 |
+
"source": [
|
706 |
+
"def accuracy_threshold(y_pred, y_true, threshold=.5, sigmoid=True):\n",
|
707 |
+
" y_pred = torch.from_numpy(y_pred)\n",
|
708 |
+
" y_true = torch.from_numpy(y_true)\n",
|
709 |
+
"\n",
|
710 |
+
" if sigmoid:\n",
|
711 |
+
" y_pred = y_pred.sigmoid()\n",
|
712 |
+
"\n",
|
713 |
+
" return ((y_pred > threshold) == y_true.bool()).float().mean().item()"
|
714 |
+
]
|
715 |
+
},
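`accuracy_threshold` measures per-label agreement, not per-report agreement: a report with three of four labels right still contributes 0.75. A toy illustration of the same computation:

```python
import numpy as np
import torch

y_pred = torch.from_numpy(np.array([[2.0, -1.0], [-0.5, 1.5]])).sigmoid()
y_true = torch.from_numpy(np.array([[1.0, 0.0], [1.0, 1.0]])).bool()
print(((y_pred > 0.5) == y_true).float().mean().item())  # 0.75: 3 of 4 label calls match
```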
|
716 |
+
{
|
717 |
+
"cell_type": "code",
|
718 |
+
"execution_count": 58,
|
719 |
+
"metadata": {
|
720 |
+
"datalore": {
|
721 |
+
"hide_input_from_viewers": true,
|
722 |
+
"hide_output_from_viewers": true,
|
723 |
+
"node_id": "1Uq3HtkaBxtHNAnSwit5cI",
|
724 |
+
"type": "CODE"
|
725 |
+
},
|
726 |
+
"gather": {
|
727 |
+
"logged": 1706449761720
|
728 |
+
}
|
729 |
+
},
|
730 |
+
"outputs": [],
|
731 |
+
"source": [
|
732 |
+
"def compute_metrics(eval_pred):\n",
|
733 |
+
" predictions, labels = eval_pred\n",
|
734 |
+
" return {'accuracy_thresh': accuracy_threshold(predictions, labels)}"
|
735 |
+
]
|
736 |
+
},
|
737 |
+
{
|
738 |
+
"cell_type": "code",
|
739 |
+
"execution_count": 63,
|
740 |
+
"metadata": {
|
741 |
+
"datalore": {
|
742 |
+
"hide_input_from_viewers": true,
|
743 |
+
"hide_output_from_viewers": true,
|
744 |
+
"node_id": "1iPZOTKPwSkTgX5dORqT89",
|
745 |
+
"type": "CODE"
|
746 |
+
},
|
747 |
+
"gather": {
|
748 |
+
"logged": 1706449761893
|
749 |
+
}
|
750 |
+
},
|
751 |
+
"outputs": [],
|
752 |
+
"source": [
|
753 |
+
"args = TrainingArguments(\n",
|
754 |
+
" output_dir=\"vaers\",\n",
|
755 |
+
" evaluation_strategy=\"epoch\",\n",
|
756 |
+
" learning_rate=2e-5,\n",
|
757 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
758 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
759 |
+
" num_train_epochs=EPOCHS,\n",
|
760 |
+
" weight_decay=.01,\n",
|
761 |
+
" logging_steps=1,\n",
|
762 |
+
" run_name=f\"daedra-training\",\n",
|
763 |
+
" report_to=[\"wandb\"]\n",
|
764 |
+
")"
|
765 |
+
]
|
766 |
+
},
|
767 |
+
{
|
768 |
+
"cell_type": "code",
|
769 |
+
"execution_count": 64,
|
770 |
+
"metadata": {
|
771 |
+
"datalore": {
|
772 |
+
"hide_input_from_viewers": true,
|
773 |
+
"hide_output_from_viewers": true,
|
774 |
+
"node_id": "bnRkNvRYltLun6gCEgL7v0",
|
775 |
+
"type": "CODE"
|
776 |
+
},
|
777 |
+
"gather": {
|
778 |
+
"logged": 1706449769103
|
779 |
+
}
|
780 |
+
},
|
781 |
+
"outputs": [],
|
782 |
+
"source": [
|
783 |
+
"multi_label_trainer = MultiLabelTrainer(\n",
|
784 |
+
" model, \n",
|
785 |
+
" args, \n",
|
786 |
+
" train_dataset=ds_enc[\"train\"], \n",
|
787 |
+
" eval_dataset=ds_enc[\"test\"], \n",
|
788 |
+
" compute_metrics=compute_metrics, \n",
|
789 |
+
" tokenizer=tokenizer\n",
|
790 |
+
")"
|
791 |
+
]
|
792 |
+
},
|
793 |
+
{
|
794 |
+
"cell_type": "code",
|
795 |
+
"execution_count": 71,
|
796 |
+
"metadata": {
|
797 |
+
"datalore": {
|
798 |
+
"hide_input_from_viewers": true,
|
799 |
+
"hide_output_from_viewers": true,
|
800 |
+
"node_id": "LO54PlDkWQdFrzV25FvduB",
|
801 |
+
"type": "CODE"
|
802 |
+
},
|
803 |
+
"gather": {
|
804 |
+
"logged": 1706449880674
|
805 |
+
}
|
806 |
+
},
|
807 |
+
"outputs": [
|
808 |
+
{
|
809 |
+
"data": {
|
810 |
+
"text/html": [
|
811 |
+
"Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
|
812 |
+
],
|
813 |
+
"text/plain": [
|
814 |
+
"<IPython.core.display.HTML object>"
|
815 |
+
]
|
816 |
+
},
|
817 |
+
"metadata": {},
|
818 |
+
"output_type": "display_data"
|
819 |
+
},
|
820 |
+
{
|
821 |
+
"data": {
|
822 |
+
"text/html": [
|
823 |
+
"Changes to your `wandb` environment variables will be ignored because your `wandb` session has already started. For more information on how to modify your settings with `wandb.init()` arguments, please refer to <a href='https://wandb.me/wandb-init' target=\"_blank\">the W&B docs</a>."
|
824 |
+
],
|
825 |
+
"text/plain": [
|
826 |
+
"<IPython.core.display.HTML object>"
|
827 |
+
]
|
828 |
+
},
|
829 |
+
"metadata": {},
|
830 |
+
"output_type": "display_data"
|
831 |
+
},
|
832 |
+
{
|
833 |
+
"name": "stderr",
|
834 |
+
"output_type": "stream",
|
835 |
+
"text": [
|
836 |
+
"\u001b[34m\u001b[1mwandb\u001b[0m: \u001b[33mWARNING\u001b[0m wandb.init() arguments ignored because wandb magic has already been initialized\n"
|
837 |
+
]
|
838 |
+
},
|
839 |
+
{
|
840 |
+
"data": {
|
841 |
+
"text/html": [
|
842 |
+
"Tracking run with wandb version 0.16.2"
|
843 |
+
],
|
844 |
+
"text/plain": [
|
845 |
+
"<IPython.core.display.HTML object>"
|
846 |
+
]
|
847 |
+
},
|
848 |
+
"metadata": {},
|
849 |
+
"output_type": "display_data"
|
850 |
+
},
|
851 |
+
{
|
852 |
+
"data": {
|
853 |
+
"text/html": [
|
854 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141352-spfdhiij</code>"
|
855 |
+
],
|
856 |
+
"text/plain": [
|
857 |
+
"<IPython.core.display.HTML object>"
|
858 |
+
]
|
859 |
+
},
|
860 |
+
"metadata": {},
|
861 |
+
"output_type": "display_data"
|
862 |
+
},
|
863 |
+
{
|
864 |
+
"data": {
|
865 |
+
"text/html": [
|
866 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
867 |
+
],
|
868 |
+
"text/plain": [
|
869 |
+
"<IPython.core.display.HTML object>"
|
870 |
+
]
|
871 |
+
},
|
872 |
+
"metadata": {},
|
873 |
+
"output_type": "display_data"
|
874 |
+
},
|
875 |
+
{
|
876 |
+
"data": {
|
877 |
+
"text/html": [
|
878 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
879 |
+
],
|
880 |
+
"text/plain": [
|
881 |
+
"<IPython.core.display.HTML object>"
|
882 |
+
]
|
883 |
+
},
|
884 |
+
"metadata": {},
|
885 |
+
"output_type": "display_data"
|
886 |
+
},
|
887 |
+
{
|
888 |
+
"data": {
|
889 |
+
"text/html": [
|
890 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij</a>"
|
891 |
+
],
|
892 |
+
"text/plain": [
|
893 |
+
"<IPython.core.display.HTML object>"
|
894 |
+
]
|
895 |
+
},
|
896 |
+
"metadata": {},
|
897 |
+
"output_type": "display_data"
|
898 |
+
},
|
899 |
+
{
|
900 |
+
"data": {
|
901 |
+
"text/html": [
|
902 |
+
"Finishing last run (ID:spfdhiij) before initializing another..."
|
903 |
+
],
|
904 |
+
"text/plain": [
|
905 |
+
"<IPython.core.display.HTML object>"
|
906 |
+
]
|
907 |
+
},
|
908 |
+
"metadata": {},
|
909 |
+
"output_type": "display_data"
|
910 |
+
},
|
911 |
+
{
|
912 |
+
"data": {
|
913 |
+
"text/html": [
|
914 |
+
" View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/spfdhiij</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
915 |
+
],
|
916 |
+
"text/plain": [
|
917 |
+
"<IPython.core.display.HTML object>"
|
918 |
+
]
|
919 |
+
},
|
920 |
+
"metadata": {},
|
921 |
+
"output_type": "display_data"
|
922 |
+
},
|
923 |
+
{
|
924 |
+
"data": {
|
925 |
+
"text/html": [
|
926 |
+
"Find logs at: <code>./wandb/run-20240128_141352-spfdhiij/logs</code>"
|
927 |
+
],
|
928 |
+
"text/plain": [
|
929 |
+
"<IPython.core.display.HTML object>"
|
930 |
+
]
|
931 |
+
},
|
932 |
+
"metadata": {},
|
933 |
+
"output_type": "display_data"
|
934 |
+
},
|
935 |
+
{
|
936 |
+
"data": {
|
937 |
+
"text/html": [
|
938 |
+
"Successfully finished last run (ID:spfdhiij). Initializing new run:<br/>"
|
939 |
+
],
|
940 |
+
"text/plain": [
|
941 |
+
"<IPython.core.display.HTML object>"
|
942 |
+
]
|
943 |
+
},
|
944 |
+
"metadata": {},
|
945 |
+
"output_type": "display_data"
|
946 |
+
},
|
947 |
+
{
|
948 |
+
"data": {
|
949 |
+
"text/html": [
|
950 |
+
"Tracking run with wandb version 0.16.2"
|
951 |
+
],
|
952 |
+
"text/plain": [
|
953 |
+
"<IPython.core.display.HTML object>"
|
954 |
+
]
|
955 |
+
},
|
956 |
+
"metadata": {},
|
957 |
+
"output_type": "display_data"
|
958 |
+
},
|
959 |
+
{
|
960 |
+
"data": {
|
961 |
+
"text/html": [
|
962 |
+
"Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/cvc-vaers-bert-dnsd/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240128_141354-mpe6cpuz</code>"
|
963 |
+
],
|
964 |
+
"text/plain": [
|
965 |
+
"<IPython.core.display.HTML object>"
|
966 |
+
]
|
967 |
+
},
|
968 |
+
"metadata": {},
|
969 |
+
"output_type": "display_data"
|
970 |
+
},
|
971 |
+
{
|
972 |
+
"data": {
|
973 |
+
"text/html": [
|
974 |
+
"Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz' target=\"_blank\">init_evaluation_run</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
975 |
+
],
|
976 |
+
"text/plain": [
|
977 |
+
"<IPython.core.display.HTML object>"
|
978 |
+
]
|
979 |
+
},
|
980 |
+
"metadata": {},
|
981 |
+
"output_type": "display_data"
|
982 |
+
},
|
983 |
+
{
|
984 |
+
"data": {
|
985 |
+
"text/html": [
|
986 |
+
" View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training</a>"
|
987 |
+
],
|
988 |
+
"text/plain": [
|
989 |
+
"<IPython.core.display.HTML object>"
|
990 |
+
]
|
991 |
+
},
|
992 |
+
"metadata": {},
|
993 |
+
"output_type": "display_data"
|
994 |
+
},
|
995 |
+
{
|
996 |
+
"data": {
|
997 |
+
"text/html": [
|
998 |
+
" View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz</a>"
|
999 |
+
],
|
1000 |
+
"text/plain": [
|
1001 |
+
"<IPython.core.display.HTML object>"
|
1002 |
+
]
|
1003 |
+
},
|
1004 |
+
"metadata": {},
|
1005 |
+
"output_type": "display_data"
|
1006 |
+
},
|
1007 |
+
{
|
1008 |
+
"data": {
|
1009 |
+
"text/html": [
|
1010 |
+
"<style>\n",
|
1011 |
+
" table.wandb td:nth-child(1) { padding: 0 10px; text-align: left ; width: auto;} td:nth-child(2) {text-align: left ; width: 100%}\n",
|
1012 |
+
" .wandb-row { display: flex; flex-direction: row; flex-wrap: wrap; justify-content: flex-start; width: 100% }\n",
|
1013 |
+
" .wandb-col { display: flex; flex-direction: column; flex-basis: 100%; flex: 1; padding: 10px; }\n",
|
1014 |
+
" </style>\n",
|
1015 |
+
"<div class=\"wandb-row\"><div class=\"wandb-col\"><h3>Run history:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>▁</td></tr><tr><td>eval/loss</td><td>▁</td></tr><tr><td>eval/runtime</td><td>▁</td></tr><tr><td>eval/samples_per_second</td><td>▁</td></tr><tr><td>eval/steps_per_second</td><td>▁</td></tr><tr><td>train/global_step</td><td>▁</td></tr></table><br/></div><div class=\"wandb-col\"><h3>Run summary:</h3><br/><table class=\"wandb\"><tr><td>eval/accuracy_thresh</td><td>0.42136</td></tr><tr><td>eval/loss</td><td>0.69069</td></tr><tr><td>eval/runtime</td><td>79.1475</td></tr><tr><td>eval/samples_per_second</td><td>343.965</td></tr><tr><td>eval/steps_per_second</td><td>2.691</td></tr><tr><td>train/global_step</td><td>0</td></tr></table><br/></div></div>"
|
1016 |
+
],
|
1017 |
+
"text/plain": [
|
1018 |
+
"<IPython.core.display.HTML object>"
|
1019 |
+
]
|
1020 |
+
},
|
1021 |
+
"metadata": {},
|
1022 |
+
"output_type": "display_data"
|
1023 |
+
},
|
1024 |
+
{
|
1025 |
+
"data": {
|
1026 |
+
"text/html": [
|
1027 |
+
" View run <strong style=\"color:#cdcd00\">init_evaluation_run</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20model%20training/runs/mpe6cpuz</a><br/>Synced 5 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
1028 |
+
],
|
1029 |
+
"text/plain": [
|
1030 |
+
"<IPython.core.display.HTML object>"
|
1031 |
+
]
|
1032 |
+
},
|
1033 |
+
"metadata": {},
|
1034 |
+
"output_type": "display_data"
|
1035 |
+
},
|
1036 |
+
{
|
1037 |
+
"data": {
|
1038 |
+
"text/html": [
|
1039 |
+
"Find logs at: <code>./wandb/run-20240128_141354-mpe6cpuz/logs</code>"
|
1040 |
+
],
|
1041 |
+
"text/plain": [
|
1042 |
+
"<IPython.core.display.HTML object>"
|
1043 |
+
]
|
1044 |
+
},
|
1045 |
+
"metadata": {},
|
1046 |
+
"output_type": "display_data"
|
1047 |
+
}
|
1048 |
+
],
|
1049 |
+
"source": [
|
1050 |
+
"if SUBSAMPLING != 1.0:\n",
|
1051 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
1052 |
+
"else:\n",
|
1053 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
1054 |
+
" \n",
|
1055 |
+
"wandb.init(name=\"init_evaluation_run\", tags=wandb_tag, magic=True)\n",
|
1056 |
+
"\n",
|
1057 |
+
"multi_label_trainer.evaluate()\n",
|
1058 |
+
"wandb.finish()"
|
1059 |
+
]
|
1060 |
+
},
|
1061 |
+
{
|
1062 |
+
"cell_type": "code",
|
1063 |
+
"execution_count": 62,
|
1064 |
+
"metadata": {
|
1065 |
+
"datalore": {
|
1066 |
+
"hide_input_from_viewers": true,
|
1067 |
+
"hide_output_from_viewers": true,
|
1068 |
+
"node_id": "hf0Ei1QXEYDmBv1VNLZ4Zw",
|
1069 |
+
"type": "CODE"
|
1070 |
+
},
|
1071 |
+
"gather": {
|
1072 |
+
"logged": 1706449934637
|
1073 |
+
}
|
1074 |
+
},
|
1075 |
+
"outputs": [
|
1076 |
+
{
|
1077 |
+
"ename": "RuntimeError",
|
1078 |
+
"evalue": "Caught RuntimeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py\", line 61, in _worker\n output = module(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 1002, in forward\n distilbert_output = self.distilbert(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 822, in forward\n return self.transformer(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 587, in forward\n layer_outputs = layer_module(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 513, in forward\n sa_output = self.attention(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 243, in forward\n scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)\nRuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 15.77 GiB total capacity; 14.69 GiB already allocated; 5.12 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n",
|
1079 |
+
"output_type": "error",
|
1080 |
+
"traceback": [
|
1081 |
+
"\u001b[0;31m---------------------------------------------------------------------------\u001b[0m",
|
1082 |
+
"\u001b[0;31mRuntimeError\u001b[0m Traceback (most recent call last)",
|
1083 |
+
"Cell \u001b[0;32mIn[62], line 1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[43mmulti_label_trainer\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtrain\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n",
|
1084 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:1539\u001b[0m, in \u001b[0;36mTrainer.train\u001b[0;34m(self, resume_from_checkpoint, trial, ignore_keys_for_eval, **kwargs)\u001b[0m\n\u001b[1;32m 1537\u001b[0m hf_hub_utils\u001b[38;5;241m.\u001b[39menable_progress_bars()\n\u001b[1;32m 1538\u001b[0m \u001b[38;5;28;01melse\u001b[39;00m:\n\u001b[0;32m-> 1539\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43minner_training_loop\u001b[49m\u001b[43m(\u001b[49m\n\u001b[1;32m 1540\u001b[0m \u001b[43m \u001b[49m\u001b[43margs\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43margs\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1541\u001b[0m \u001b[43m \u001b[49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mresume_from_checkpoint\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1542\u001b[0m \u001b[43m \u001b[49m\u001b[43mtrial\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mtrial\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1543\u001b[0m \u001b[43m \u001b[49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[38;5;241;43m=\u001b[39;49m\u001b[43mignore_keys_for_eval\u001b[49m\u001b[43m,\u001b[49m\n\u001b[1;32m 1544\u001b[0m \u001b[43m \u001b[49m\u001b[43m)\u001b[49m\n",
|
1085 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:1869\u001b[0m, in \u001b[0;36mTrainer._inner_training_loop\u001b[0;34m(self, batch_size, args, resume_from_checkpoint, trial, ignore_keys_for_eval)\u001b[0m\n\u001b[1;32m 1866\u001b[0m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcallback_handler\u001b[38;5;241m.\u001b[39mon_step_begin(args, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcontrol)\n\u001b[1;32m 1868\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39maccelerator\u001b[38;5;241m.\u001b[39maccumulate(model):\n\u001b[0;32m-> 1869\u001b[0m tr_loss_step \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mtraining_step\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1871\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m (\n\u001b[1;32m 1872\u001b[0m args\u001b[38;5;241m.\u001b[39mlogging_nan_inf_filter\n\u001b[1;32m 1873\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m is_torch_tpu_available()\n\u001b[1;32m 1874\u001b[0m \u001b[38;5;129;01mand\u001b[39;00m (torch\u001b[38;5;241m.\u001b[39misnan(tr_loss_step) \u001b[38;5;129;01mor\u001b[39;00m torch\u001b[38;5;241m.\u001b[39misinf(tr_loss_step))\n\u001b[1;32m 1875\u001b[0m ):\n\u001b[1;32m 1876\u001b[0m \u001b[38;5;66;03m# if loss is nan or inf simply add the average of previous logged losses\u001b[39;00m\n\u001b[1;32m 1877\u001b[0m tr_loss \u001b[38;5;241m+\u001b[39m\u001b[38;5;241m=\u001b[39m tr_loss \u001b[38;5;241m/\u001b[39m (\u001b[38;5;241m1\u001b[39m \u001b[38;5;241m+\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mstate\u001b[38;5;241m.\u001b[39mglobal_step \u001b[38;5;241m-\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_globalstep_last_logged)\n",
|
1086 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/trainer.py:2768\u001b[0m, in \u001b[0;36mTrainer.training_step\u001b[0;34m(self, model, inputs)\u001b[0m\n\u001b[1;32m 2765\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m loss_mb\u001b[38;5;241m.\u001b[39mreduce_mean()\u001b[38;5;241m.\u001b[39mdetach()\u001b[38;5;241m.\u001b[39mto(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mdevice)\n\u001b[1;32m 2767\u001b[0m \u001b[38;5;28;01mwith\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mcompute_loss_context_manager():\n\u001b[0;32m-> 2768\u001b[0m loss \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mcompute_loss\u001b[49m\u001b[43m(\u001b[49m\u001b[43mmodel\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 2770\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39margs\u001b[38;5;241m.\u001b[39mn_gpu \u001b[38;5;241m>\u001b[39m \u001b[38;5;241m1\u001b[39m:\n\u001b[1;32m 2771\u001b[0m loss \u001b[38;5;241m=\u001b[39m loss\u001b[38;5;241m.\u001b[39mmean() \u001b[38;5;66;03m# mean() to average on multi-gpu parallel training\u001b[39;00m\n",
|
1087 |
+
"Cell \u001b[0;32mIn[55], line 4\u001b[0m, in \u001b[0;36mMultiLabelTrainer.compute_loss\u001b[0;34m(self, model, inputs, return_outputs)\u001b[0m\n\u001b[1;32m 2\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mcompute_loss\u001b[39m(\u001b[38;5;28mself\u001b[39m, model, inputs, return_outputs\u001b[38;5;241m=\u001b[39m\u001b[38;5;28;01mFalse\u001b[39;00m):\n\u001b[1;32m 3\u001b[0m labels \u001b[38;5;241m=\u001b[39m inputs\u001b[38;5;241m.\u001b[39mpop(\u001b[38;5;124m\"\u001b[39m\u001b[38;5;124mlabels\u001b[39m\u001b[38;5;124m\"\u001b[39m)\n\u001b[0;32m----> 4\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[43mmodel\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43minputs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 5\u001b[0m logits \u001b[38;5;241m=\u001b[39m outputs\u001b[38;5;241m.\u001b[39mlogits\n\u001b[1;32m 6\u001b[0m loss_fct \u001b[38;5;241m=\u001b[39m torch\u001b[38;5;241m.\u001b[39mnn\u001b[38;5;241m.\u001b[39mBCEWithLogitsLoss()\n",
|
1088 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py:1130\u001b[0m, in \u001b[0;36mModule._call_impl\u001b[0;34m(self, *input, **kwargs)\u001b[0m\n\u001b[1;32m 1126\u001b[0m \u001b[38;5;66;03m# If we don't have any hooks, we want to skip the rest of the logic in\u001b[39;00m\n\u001b[1;32m 1127\u001b[0m \u001b[38;5;66;03m# this function, and just call forward.\u001b[39;00m\n\u001b[1;32m 1128\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;129;01mnot\u001b[39;00m (\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_backward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39m_forward_pre_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_backward_hooks\n\u001b[1;32m 1129\u001b[0m \u001b[38;5;129;01mor\u001b[39;00m _global_forward_hooks \u001b[38;5;129;01mor\u001b[39;00m _global_forward_pre_hooks):\n\u001b[0;32m-> 1130\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mforward_call\u001b[49m\u001b[43m(\u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;28;43minput\u001b[39;49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[38;5;241;43m*\u001b[39;49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 1131\u001b[0m \u001b[38;5;66;03m# Do not call functions when jit is used\u001b[39;00m\n\u001b[1;32m 1132\u001b[0m full_backward_hooks, non_full_backward_hooks \u001b[38;5;241m=\u001b[39m [], []\n",
|
1089 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:168\u001b[0m, in \u001b[0;36mDataParallel.forward\u001b[0;34m(self, *inputs, **kwargs)\u001b[0m\n\u001b[1;32m 166\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule(\u001b[38;5;241m*\u001b[39minputs[\u001b[38;5;241m0\u001b[39m], \u001b[38;5;241m*\u001b[39m\u001b[38;5;241m*\u001b[39mkwargs[\u001b[38;5;241m0\u001b[39m])\n\u001b[1;32m 167\u001b[0m replicas \u001b[38;5;241m=\u001b[39m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mreplicate(\u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mmodule, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mdevice_ids[:\u001b[38;5;28mlen\u001b[39m(inputs)])\n\u001b[0;32m--> 168\u001b[0m outputs \u001b[38;5;241m=\u001b[39m \u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 169\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39mgather(outputs, \u001b[38;5;28mself\u001b[39m\u001b[38;5;241m.\u001b[39moutput_device)\n",
|
1090 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/data_parallel.py:178\u001b[0m, in \u001b[0;36mDataParallel.parallel_apply\u001b[0;34m(self, replicas, inputs, kwargs)\u001b[0m\n\u001b[1;32m 177\u001b[0m \u001b[38;5;28;01mdef\u001b[39;00m \u001b[38;5;21mparallel_apply\u001b[39m(\u001b[38;5;28mself\u001b[39m, replicas, inputs, kwargs):\n\u001b[0;32m--> 178\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m \u001b[43mparallel_apply\u001b[49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43minputs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[43mkwargs\u001b[49m\u001b[43m,\u001b[49m\u001b[43m \u001b[49m\u001b[38;5;28;43mself\u001b[39;49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mdevice_ids\u001b[49m\u001b[43m[\u001b[49m\u001b[43m:\u001b[49m\u001b[38;5;28;43mlen\u001b[39;49m\u001b[43m(\u001b[49m\u001b[43mreplicas\u001b[49m\u001b[43m)\u001b[49m\u001b[43m]\u001b[49m\u001b[43m)\u001b[49m\n",
|
1091 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py:86\u001b[0m, in \u001b[0;36mparallel_apply\u001b[0;34m(modules, inputs, kwargs_tup, devices)\u001b[0m\n\u001b[1;32m 84\u001b[0m output \u001b[38;5;241m=\u001b[39m results[i]\n\u001b[1;32m 85\u001b[0m \u001b[38;5;28;01mif\u001b[39;00m \u001b[38;5;28misinstance\u001b[39m(output, ExceptionWrapper):\n\u001b[0;32m---> 86\u001b[0m \u001b[43moutput\u001b[49m\u001b[38;5;241;43m.\u001b[39;49m\u001b[43mreraise\u001b[49m\u001b[43m(\u001b[49m\u001b[43m)\u001b[49m\n\u001b[1;32m 87\u001b[0m outputs\u001b[38;5;241m.\u001b[39mappend(output)\n\u001b[1;32m 88\u001b[0m \u001b[38;5;28;01mreturn\u001b[39;00m outputs\n",
|
1092 |
+
"File \u001b[0;32m/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/_utils.py:461\u001b[0m, in \u001b[0;36mExceptionWrapper.reraise\u001b[0;34m(self)\u001b[0m\n\u001b[1;32m 457\u001b[0m \u001b[38;5;28;01mexcept\u001b[39;00m \u001b[38;5;167;01mTypeError\u001b[39;00m:\n\u001b[1;32m 458\u001b[0m \u001b[38;5;66;03m# If the exception takes multiple arguments, don't try to\u001b[39;00m\n\u001b[1;32m 459\u001b[0m \u001b[38;5;66;03m# instantiate since we don't know how to\u001b[39;00m\n\u001b[1;32m 460\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m \u001b[38;5;167;01mRuntimeError\u001b[39;00m(msg) \u001b[38;5;28;01mfrom\u001b[39;00m \u001b[38;5;28;01mNone\u001b[39;00m\n\u001b[0;32m--> 461\u001b[0m \u001b[38;5;28;01mraise\u001b[39;00m exception\n",
|
1093 |
+
"\u001b[0;31mRuntimeError\u001b[0m: Caught RuntimeError in replica 0 on device 0.\nOriginal Traceback (most recent call last):\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/parallel/parallel_apply.py\", line 61, in _worker\n output = module(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 1002, in forward\n distilbert_output = self.distilbert(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 822, in forward\n return self.transformer(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 587, in forward\n layer_outputs = layer_module(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 513, in forward\n sa_output = self.attention(\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/torch/nn/modules/module.py\", line 1130, in _call_impl\n return forward_call(*input, **kwargs)\n File \"/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/transformers/models/distilbert/modeling_distilbert.py\", line 243, in forward\n scores = torch.matmul(q, k.transpose(2, 3)) # (bs, n_heads, q_length, k_length)\nRuntimeError: CUDA out of memory. Tried to allocate 96.00 MiB (GPU 0; 15.77 GiB total capacity; 14.69 GiB already allocated; 5.12 MiB free; 14.72 GiB reserved in total by PyTorch) If reserved memory is >> allocated memory try setting max_split_size_mb to avoid fragmentation. See documentation for Memory Management and PYTORCH_CUDA_ALLOC_CONF\n"
|
1094 |
+
]
|
1095 |
+
}
|
1096 |
+
],
|
1097 |
+
"source": [
|
1098 |
+
"if SUBSAMPLING != 1.0:\n",
|
1099 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
1100 |
+
"else:\n",
|
1101 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
1102 |
+
" \n",
|
1103 |
+
"wandb.init(name=\"daedra_training_run\", tags=wandb_tag, magic=True)\n",
|
1104 |
+
"\n",
|
1105 |
+
"multi_label_trainer.train()\n",
|
1106 |
+
"wandb.finish()"
|
1107 |
+
]
|
1108 |
+
},
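The run above dies with a CUDA out-of-memory error on a 16 GiB card while `DataParallel` replicates the model across both GPUs. A hedged sketch of the usual mitigations (not necessarily what this notebook ultimately settled on): a smaller per-device micro-batch recovered through gradient accumulation, plus fp16 to shrink activation memory.

```python
from transformers import TrainingArguments

# Illustrative values only; BATCH_SIZE and EPOCHS are defined elsewhere
# in the notebook and the real fix may differ.
args = TrainingArguments(
    output_dir="vaers",
    per_device_train_batch_size=16,   # smaller micro-batch per GPU
    gradient_accumulation_steps=2,    # keeps the effective batch at 32 per device
    fp16=True,                        # half-precision activations and gradients
    num_train_epochs=1,
)
```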
|
1109 |
+
{
|
1110 |
+
"cell_type": "markdown",
|
1111 |
+
"metadata": {},
|
1112 |
+
"source": [
|
1113 |
+
"### Evaluation"
|
1114 |
+
]
|
1115 |
+
},
|
1116 |
+
{
|
1117 |
+
"cell_type": "markdown",
|
1118 |
+
"metadata": {},
|
1119 |
+
"source": [
|
1120 |
+
"We instantiate a classifier `pipeline` and push it to CUDA."
|
1121 |
+
]
|
1122 |
+
},
|
1123 |
+
{
|
1124 |
+
"cell_type": "code",
|
1125 |
+
"execution_count": null,
|
1126 |
+
"metadata": {
|
1127 |
+
"datalore": {
|
1128 |
+
"hide_input_from_viewers": true,
|
1129 |
+
"hide_output_from_viewers": true,
|
1130 |
+
"node_id": "kHoUdBeqcyVXDSGv54C4aE",
|
1131 |
+
"type": "CODE"
|
1132 |
+
},
|
1133 |
+
"gather": {
|
1134 |
+
"logged": 1706411459928
|
1135 |
+
}
|
1136 |
+
},
|
1137 |
+
"outputs": [],
|
1138 |
+
"source": [
|
1139 |
+
"classifier = pipeline(\"text-classification\", \n",
|
1140 |
+
" model, \n",
|
1141 |
+
" tokenizer=tokenizer, \n",
|
1142 |
+
" device=\"cuda:0\")"
|
1143 |
+
]
|
1144 |
+
},
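For reference, a self-contained sketch of the same pipeline call shape (using the base checkpoint, so the scores are meaningless until the fine-tuned weights are loaded):

```python
from transformers import pipeline

clf = pipeline("text-classification", model="distilbert-base-uncased", device=-1)
# top_k=None returns a score for every label instead of only the argmax,
# which is what a multi-label readout needs.
print(clf("Patient was hospitalised with chest pain.", top_k=None))
```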
|
1145 |
+
{
|
1146 |
+
"cell_type": "markdown",
|
1147 |
+
"metadata": {},
|
1148 |
+
"source": [
|
1149 |
+
"We use the same tokenizer used for training to tokenize/encode the validation set."
|
1150 |
+
]
|
1151 |
+
},
|
1152 |
+
{
|
1153 |
+
"cell_type": "code",
|
1154 |
+
"execution_count": null,
|
1155 |
+
"metadata": {
|
1156 |
+
"datalore": {
|
1157 |
+
"hide_input_from_viewers": true,
|
1158 |
+
"hide_output_from_viewers": true,
|
1159 |
+
"node_id": "Dr5WCWA6jL51NR1fSrQu6Z",
|
1160 |
+
"type": "CODE"
|
1161 |
+
},
|
1162 |
+
"gather": {
|
1163 |
+
"logged": 1706411523285
|
1164 |
+
}
|
1165 |
+
},
|
1166 |
+
"outputs": [],
|
1167 |
+
"source": [
|
1168 |
+
"test_encodings = tokenizer.batch_encode_plus(dataset[\"val\"][\"text\"], \n",
|
1169 |
+
" max_length=None, \n",
|
1170 |
+
" padding='max_length', \n",
|
1171 |
+
" return_token_type_ids=True, \n",
|
1172 |
+
" truncation=True)"
|
1173 |
+
]
|
1174 |
+
},
|
1175 |
+
{
|
1176 |
+
"cell_type": "markdown",
|
1177 |
+
"metadata": {},
|
1178 |
+
"source": [
|
1179 |
+
"Once we've made the data loadable by putting it into a `DataLoader`, we "
|
1180 |
+
]
|
1181 |
+
},
|
1182 |
+
{
|
1183 |
+
"cell_type": "code",
|
1184 |
+
"execution_count": null,
|
1185 |
+
"metadata": {
|
1186 |
+
"datalore": {
|
1187 |
+
"hide_input_from_viewers": true,
|
1188 |
+
"hide_output_from_viewers": true,
|
1189 |
+
"node_id": "MWfGq2tTkJNzFiDoUPq2X7",
|
1190 |
+
"type": "CODE"
|
1191 |
+
},
|
1192 |
+
"gather": {
|
1193 |
+
"logged": 1706411543379
|
1194 |
+
}
|
1195 |
+
},
|
1196 |
+
"outputs": [],
|
1197 |
+
"source": [
|
1198 |
+
"test_data = torch.utils.data.TensorDataset(torch.tensor(test_encodings['input_ids']), \n",
|
1199 |
+
" torch.tensor(test_encodings['attention_mask']), \n",
|
1200 |
+
" torch.tensor(ds_enc[\"val\"][\"labels\"]), \n",
|
1201 |
+
" torch.tensor(test_encodings['token_type_ids']))\n",
|
1202 |
+
"test_dataloader = torch.utils.data.DataLoader(test_data, \n",
|
1203 |
+
" sampler=torch.utils.data.SequentialSampler(test_data), \n",
|
1204 |
+
" batch_size=BATCH_SIZE)"
|
1205 |
+
]
|
1206 |
+
},
|
1207 |
+
{
|
1208 |
+
"cell_type": "code",
|
1209 |
+
"execution_count": null,
|
1210 |
+
"metadata": {
|
1211 |
+
"datalore": {
|
1212 |
+
"hide_input_from_viewers": true,
|
1213 |
+
"hide_output_from_viewers": true,
|
1214 |
+
"node_id": "1SJCSrQTRCexFCNCIyRrzL",
|
1215 |
+
"type": "CODE"
|
1216 |
+
},
|
1217 |
+
"gather": {
|
1218 |
+
"logged": 1706411587843
|
1219 |
+
}
|
1220 |
+
},
|
1221 |
+
"outputs": [],
|
1222 |
+
"source": [
|
1223 |
+
"model.eval()\n",
|
1224 |
+
"\n",
|
1225 |
+
"logit_preds, true_labels, pred_labels, tokenized_texts = [], [], [], []\n",
|
1226 |
+
"\n",
|
1227 |
+
"for i, batch in enumerate(test_dataloader):\n",
|
1228 |
+
" batch = tuple(t.to(device) for t in batch)\n",
|
1229 |
+
" \n",
|
1230 |
+
" # Unpack the inputs from our dataloader\n",
|
1231 |
+
" b_input_ids, b_input_mask, b_labels, b_token_types = batch\n",
|
1232 |
+
" \n",
|
1233 |
+
" with torch.no_grad():\n",
|
1234 |
+
" outs = model(b_input_ids, attention_mask=b_input_mask)\n",
|
1235 |
+
" b_logit_pred = outs[0]\n",
|
1236 |
+
" pred_label = torch.sigmoid(b_logit_pred)\n",
|
1237 |
+
"\n",
|
1238 |
+
" b_logit_pred = b_logit_pred.detach().cpu().numpy()\n",
|
1239 |
+
" pred_label = pred_label.to('cpu').numpy()\n",
|
1240 |
+
" b_labels = b_labels.to('cpu').numpy()\n",
|
1241 |
+
"\n",
|
1242 |
+
" tokenized_texts.append(b_input_ids)\n",
|
1243 |
+
" logit_preds.append(b_logit_pred)\n",
|
1244 |
+
" true_labels.append(b_labels)\n",
|
1245 |
+
" pred_labels.append(pred_label)\n",
|
1246 |
+
"\n",
|
1247 |
+
"# Flatten outputs\n",
|
1248 |
+
"tokenized_texts = [item for sublist in tokenized_texts for item in sublist]\n",
|
1249 |
+
"pred_labels = [item for sublist in pred_labels for item in sublist]\n",
|
1250 |
+
"true_labels = [item for sublist in true_labels for item in sublist]\n",
|
1251 |
+
"\n",
|
1252 |
+
"# Converting flattened binary values to boolean values\n",
|
1253 |
+
"true_bools = [tl == 1 for tl in true_labels]\n",
|
1254 |
+
"pred_bools = [pl > 0.50 for pl in pred_labels] "
|
1255 |
+
]
|
1256 |
+
},
|
1257 |
+
{
|
1258 |
+
"cell_type": "markdown",
|
1259 |
+
"metadata": {},
|
1260 |
+
"source": [
|
1261 |
+
"We create a classification report:"
|
1262 |
+
]
|
1263 |
+
},
|
1264 |
+
{
|
1265 |
+
"cell_type": "code",
|
1266 |
+
"execution_count": null,
|
1267 |
+
"metadata": {
|
1268 |
+
"datalore": {
|
1269 |
+
"hide_input_from_viewers": true,
|
1270 |
+
"hide_output_from_viewers": true,
|
1271 |
+
"node_id": "eBprrgF086mznPbPVBpOLS",
|
1272 |
+
"type": "CODE"
|
1273 |
+
},
|
1274 |
+
"gather": {
|
1275 |
+
"logged": 1706411588249
|
1276 |
+
}
|
1277 |
+
},
|
1278 |
+
"outputs": [],
|
1279 |
+
"source": [
|
1280 |
+
"print('Test F1 Accuracy: ', f1_score(true_bools, pred_bools, average='micro'))\n",
|
1281 |
+
"print('Test Flat Accuracy: ', accuracy_score(true_bools, pred_bools), '\\n')\n",
|
1282 |
+
"clf_report = classification_report(true_bools, pred_bools, target_names=CLASS_NAMES)\n",
|
1283 |
+
"print(clf_report)"
|
1284 |
+
]
|
1285 |
+
},
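Micro-averaged F1 pools true and false positives across all labels, so frequent labels dominate it, while the "flat" accuracy counts a report as correct only when every label matches. A toy illustration of the difference:

```python
from sklearn.metrics import accuracy_score, f1_score

y_true = [[1, 0, 0, 1], [0, 0, 1, 0]]
y_pred = [[1, 0, 0, 0], [0, 0, 1, 0]]
print(f1_score(y_true, y_pred, average="micro"))  # 0.8  (label-level, pooled)
print(accuracy_score(y_true, y_pred))             # 0.5  (exact match per report)
```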
|
1286 |
+
{
|
1287 |
+
"cell_type": "markdown",
|
1288 |
+
"metadata": {},
|
1289 |
+
"source": [
|
1290 |
+
"Finally, we render a 'head to head' comparison table that maps each text prediction to actual and predicted labels."
|
1291 |
+
]
|
1292 |
+
},
|
1293 |
+
{
|
1294 |
+
"cell_type": "code",
|
1295 |
+
"execution_count": null,
|
1296 |
+
"metadata": {
|
1297 |
+
"datalore": {
|
1298 |
+
"hide_input_from_viewers": true,
|
1299 |
+
"hide_output_from_viewers": true,
|
1300 |
+
"node_id": "yELHY0IEwMlMw3x6e7hoD1",
|
1301 |
+
"type": "CODE"
|
1302 |
+
},
|
1303 |
+
"gather": {
|
1304 |
+
"logged": 1706411588638
|
1305 |
+
}
|
1306 |
+
},
|
1307 |
+
"outputs": [],
|
1308 |
+
"source": [
|
1309 |
+
"# Creating a map of class names from class numbers\n",
|
1310 |
+
"idx2label = dict(zip(range(len(CLASS_NAMES)), CLASS_NAMES))"
|
1311 |
+
]
|
1312 |
+
},
|
1313 |
+
{
|
1314 |
+
"cell_type": "code",
|
1315 |
+
"execution_count": null,
|
1316 |
+
"metadata": {
|
1317 |
+
"datalore": {
|
1318 |
+
"hide_input_from_viewers": true,
|
1319 |
+
"hide_output_from_viewers": true,
|
1320 |
+
"node_id": "jH0S35dDteUch01sa6me6e",
|
1321 |
+
"type": "CODE"
|
1322 |
+
},
|
1323 |
+
"gather": {
|
1324 |
+
"logged": 1706411589004
|
1325 |
+
}
|
1326 |
+
},
|
1327 |
+
"outputs": [],
|
1328 |
+
"source": [
|
1329 |
+
"true_label_idxs, pred_label_idxs = [], []\n",
|
1330 |
+
"\n",
|
1331 |
+
"for vals in true_bools:\n",
|
1332 |
+
" true_label_idxs.append(np.where(vals)[0].flatten().tolist())\n",
|
1333 |
+
"for vals in pred_bools:\n",
|
1334 |
+
" pred_label_idxs.append(np.where(vals)[0].flatten().tolist())"
|
1335 |
+
]
|
1336 |
+
},
|
1337 |
+
{
|
1338 |
+
"cell_type": "code",
|
1339 |
+
"execution_count": null,
|
1340 |
+
"metadata": {
|
1341 |
+
"datalore": {
|
1342 |
+
"hide_input_from_viewers": true,
|
1343 |
+
"hide_output_from_viewers": true,
|
1344 |
+
"node_id": "h4vHL8XdGpayZ6xLGJUF6F",
|
1345 |
+
"type": "CODE"
|
1346 |
+
},
|
1347 |
+
"gather": {
|
1348 |
+
"logged": 1706411589301
|
1349 |
+
}
|
1350 |
+
},
|
1351 |
+
"outputs": [],
|
1352 |
+
"source": [
|
1353 |
+
"true_label_texts, pred_label_texts = [], []\n",
|
1354 |
+
"\n",
|
1355 |
+
"for vals in true_label_idxs:\n",
|
1356 |
+
" if vals:\n",
|
1357 |
+
" true_label_texts.append([idx2label[val] for val in vals])\n",
|
1358 |
+
" else:\n",
|
1359 |
+
" true_label_texts.append(vals)\n",
|
1360 |
+
"\n",
|
1361 |
+
"for vals in pred_label_idxs:\n",
|
1362 |
+
" if vals:\n",
|
1363 |
+
" pred_label_texts.append([idx2label[val] for val in vals])\n",
|
1364 |
+
" else:\n",
|
1365 |
+
" pred_label_texts.append(vals)"
|
1366 |
+
]
|
1367 |
+
},
|
1368 |
+
{
|
1369 |
+
"cell_type": "code",
|
1370 |
+
"execution_count": null,
|
1371 |
+
"metadata": {
|
1372 |
+
"datalore": {
|
1373 |
+
"hide_input_from_viewers": true,
|
1374 |
+
"hide_output_from_viewers": true,
|
1375 |
+
"node_id": "SxUmVHfQISEeptg1SawOmB",
|
1376 |
+
"type": "CODE"
|
1377 |
+
},
|
1378 |
+
"gather": {
|
1379 |
+
"logged": 1706411591952
|
1380 |
+
}
|
1381 |
+
},
|
1382 |
+
"outputs": [],
|
1383 |
+
"source": [
|
1384 |
+
"symptom_texts = [tokenizer.decode(text,\n",
|
1385 |
+
" skip_special_tokens=True,\n",
|
1386 |
+
" clean_up_tokenization_spaces=False) for text in tokenized_texts]"
|
1387 |
+
]
|
1388 |
+
},
|
1389 |
+
{
|
1390 |
+
"cell_type": "code",
|
1391 |
+
"execution_count": null,
|
1392 |
+
"metadata": {
|
1393 |
+
"datalore": {
|
1394 |
+
"hide_input_from_viewers": true,
|
1395 |
+
"hide_output_from_viewers": true,
|
1396 |
+
"node_id": "BxFNigNGRLTOqraI55BPSH",
|
1397 |
+
"type": "CODE"
|
1398 |
+
},
|
1399 |
+
"gather": {
|
1400 |
+
"logged": 1706411592512
|
1401 |
+
}
|
1402 |
+
},
|
1403 |
+
"outputs": [],
|
1404 |
+
"source": [
|
1405 |
+
"comparisons_df = pd.DataFrame({'symptom_text': symptom_texts, \n",
|
1406 |
+
" 'true_labels': true_label_texts, \n",
|
1407 |
+
" 'pred_labels':pred_label_texts})\n",
|
1408 |
+
"comparisons_df.to_csv('comparisons.csv')\n",
|
1409 |
+
"comparisons_df"
|
1410 |
+
]
|
1411 |
+
},
|
1412 |
+
{
|
1413 |
+
"cell_type": "markdown",
|
1414 |
+
"metadata": {},
|
1415 |
+
"source": [
|
1416 |
+
"### Shapley analysis"
|
1417 |
+
]
|
1418 |
+
},
|
1419 |
+
{
|
1420 |
+
"cell_type": "code",
|
1421 |
+
"execution_count": null,
|
1422 |
+
"metadata": {
|
1423 |
+
"datalore": {
|
1424 |
+
"hide_input_from_viewers": true,
|
1425 |
+
"hide_output_from_viewers": true,
|
1426 |
+
"node_id": "OpdZcoenX2HwzLdai7K5UA",
|
1427 |
+
"type": "CODE"
|
1428 |
+
},
|
1429 |
+
"gather": {
|
1430 |
+
"logged": 1706415109071
|
1431 |
+
}
|
1432 |
+
},
|
1433 |
+
"outputs": [],
|
1434 |
+
"source": [
|
1435 |
+
"explainer = shap.Explainer(classifier, output_names=CLASS_NAMES)"
|
1436 |
+
]
|
1437 |
+
},
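`shap.Explainer` wraps the `pipeline` as a black box: it masks out spans of the input text and attributes the change in each class score to the tokens it removed, with `output_names` labelling the per-class explanations. A self-contained sketch of the same call shape (base checkpoint and illustrative label names, so the attributions themselves are meaningless here):

```python
import shap
from transformers import pipeline

# shap perturbs token spans of the text and measures the effect on each
# class score; with the notebook's classifier this runs over CLASS_NAMES.
clf = pipeline("text-classification", model="distilbert-base-uncased", device=-1)
explainer = shap.Explainer(clf, output_names=["NEGATIVE", "POSITIVE"])  # illustrative names
shap_values = explainer(["Patient died two days after vaccination."])
# shap.plots.text(shap_values)  # renders token-level attributions in a notebook
```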
|
1438 |
+
{
|
1439 |
+
"cell_type": "markdown",
|
1440 |
+
"metadata": {
|
1441 |
+
"nteract": {
|
1442 |
+
"transient": {
|
1443 |
+
"deleting": false
|
1444 |
+
}
|
1445 |
+
}
|
1446 |
+
},
|
1447 |
+
"source": [
|
1448 |
+
"#### Sampling correct predictions\n",
|
1449 |
+
"\n",
|
1450 |
+
"First, let's look at some correct predictions of deaths:"
|
1451 |
+
]
|
1452 |
+
},
|
1453 |
+
{
|
1454 |
+
"cell_type": "code",
|
1455 |
+
"execution_count": null,
|
1456 |
+
"metadata": {
|
1457 |
+
"collapsed": false,
|
1458 |
+
"gather": {
|
1459 |
+
"logged": 1706414973990
|
1460 |
+
},
|
1461 |
+
"jupyter": {
|
1462 |
+
"outputs_hidden": false
|
1463 |
+
},
|
1464 |
+
"nteract": {
|
1465 |
+
"transient": {
|
1466 |
+
"deleting": false
|
1467 |
+
}
|
1468 |
+
}
|
1469 |
+
},
|
1470 |
+
"outputs": [],
|
1471 |
+
"source": [
|
1472 |
+
"correct_death_predictions = comparisons_df[comparisons_df['true_labels'].astype(str) == \"['DIED']\"]"
|
1473 |
+
]
|
1474 |
+
},
|
1475 |
+
{
|
1476 |
+
"cell_type": "code",
|
1477 |
+
"execution_count": null,
|
1478 |
+
"metadata": {
|
1479 |
+
"collapsed": false,
|
1480 |
+
"gather": {
|
1481 |
+
"logged": 1706415114683
|
1482 |
+
},
|
1483 |
+
"jupyter": {
|
1484 |
+
"outputs_hidden": false
|
1485 |
+
},
|
1486 |
+
"nteract": {
|
1487 |
+
"transient": {
|
1488 |
+
"deleting": false
|
1489 |
+
}
|
1490 |
+
}
|
1491 |
+
},
|
1492 |
+
"outputs": [],
|
1493 |
+
"source": [
|
1494 |
+
"texts = [i[:512] for i in correct_death_predictions.sample(n=6).symptom_text]\n",
|
1495 |
+
"idxs = [i for i in range(len(texts))]\n",
|
1496 |
+
"\n",
|
1497 |
+
"d_s = Dataset(Table.from_arrays([idxs, texts], names=[\"idx\", \"texts\"]))"
|
1498 |
+
]
|
1499 |
+
},
|
1500 |
+
{
|
1501 |
+
"cell_type": "code",
|
1502 |
+
"execution_count": null,
|
1503 |
+
"metadata": {
|
1504 |
+
"collapsed": false,
|
1505 |
+
"gather": {
|
1506 |
+
"logged": 1706415129229
|
1507 |
+
},
|
1508 |
+
"jupyter": {
|
1509 |
+
"outputs_hidden": false
|
1510 |
+
},
|
1511 |
+
"nteract": {
|
1512 |
+
"transient": {
|
1513 |
+
"deleting": false
|
1514 |
+
}
|
1515 |
+
}
|
1516 |
+
},
|
1517 |
+
"outputs": [],
|
1518 |
+
"source": [
|
1519 |
+
"shap_values = explainer(d_s[\"texts\"])"
|
1520 |
+
]
|
1521 |
+
},
|
1522 |
+
{
|
1523 |
+
"cell_type": "code",
|
1524 |
+
"execution_count": null,
|
1525 |
+
"metadata": {
|
1526 |
+
"collapsed": false,
|
1527 |
+
"gather": {
|
1528 |
+
"logged": 1706415151494
|
1529 |
+
},
|
1530 |
+
"jupyter": {
|
1531 |
+
"outputs_hidden": false
|
1532 |
+
},
|
1533 |
+
"nteract": {
|
1534 |
+
"transient": {
|
1535 |
+
"deleting": false
|
1536 |
+
}
|
1537 |
+
}
|
1538 |
+
},
|
1539 |
+
"outputs": [],
|
1540 |
+
"source": [
|
1541 |
+
"shap.plots.text(shap_values)"
|
1542 |
+
]
|
1543 |
+
},
|
1544 |
+
{
|
1545 |
+
"cell_type": "code",
|
1546 |
+
"execution_count": null,
|
1547 |
+
"metadata": {
|
1548 |
+
"collapsed": false,
|
1549 |
+
"jupyter": {
|
1550 |
+
"outputs_hidden": false
|
1551 |
+
},
|
1552 |
+
"nteract": {
|
1553 |
+
"transient": {
|
1554 |
+
"deleting": false
|
1555 |
+
}
|
1556 |
+
}
|
1557 |
+
},
|
1558 |
+
"outputs": [],
|
1559 |
+
"source": []
|
1560 |
+
}
|
1561 |
+
],
|
1562 |
+
"metadata": {
|
1563 |
+
"datalore": {
|
1564 |
+
"base_environment": "default",
|
1565 |
+
"computation_mode": "JUPYTER",
|
1566 |
+
"package_manager": "pip",
|
1567 |
+
"packages": [
|
1568 |
+
{
|
1569 |
+
"name": "datasets",
|
1570 |
+
"source": "PIP",
|
1571 |
+
"version": "2.16.1"
|
1572 |
+
},
|
1573 |
+
{
|
1574 |
+
"name": "torch",
|
1575 |
+
"source": "PIP",
|
1576 |
+
"version": "2.1.2"
|
1577 |
+
},
|
1578 |
+
{
|
1579 |
+
"name": "accelerate",
|
1580 |
+
"source": "PIP",
|
1581 |
+
"version": "0.26.1"
|
1582 |
+
}
|
1583 |
+
],
|
1584 |
+
"report_row_ids": [
|
1585 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
1586 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
1587 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
1588 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
1589 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
1590 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
1591 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
1592 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
1593 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
1594 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
1595 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
1596 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
1597 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
1598 |
+
],
|
1599 |
+
"version": 3
|
1600 |
+
},
|
1601 |
+
"kernelspec": {
|
1602 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow",
|
1603 |
+
"language": "python",
|
1604 |
+
"name": "python38-azureml-pt-tf"
|
1605 |
+
},
|
1606 |
+
"language_info": {
|
1607 |
+
"codemirror_mode": {
|
1608 |
+
"name": "ipython",
|
1609 |
+
"version": 3
|
1610 |
+
},
|
1611 |
+
"file_extension": ".py",
|
1612 |
+
"mimetype": "text/x-python",
|
1613 |
+
"name": "python",
|
1614 |
+
"nbconvert_exporter": "python",
|
1615 |
+
"pygments_lexer": "ipython3",
|
1616 |
+
"version": "3.8.5"
|
1617 |
+
},
|
1618 |
+
"microsoft": {
|
1619 |
+
"host": {
|
1620 |
+
"AzureML": {
|
1621 |
+
"notebookHasBeenCompleted": true
|
1622 |
+
}
|
1623 |
+
},
|
1624 |
+
"ms_spell_check": {
|
1625 |
+
"ms_spell_check_language": "en"
|
1626 |
+
}
|
1627 |
+
},
|
1628 |
+
"nteract": {
|
1629 |
+
"version": "nteract-front-end@1.0.0"
|
1630 |
+
}
|
1631 |
+
},
|
1632 |
+
"nbformat": 4,
|
1633 |
+
"nbformat_minor": 4
|
1634 |
+
}
|
notebooks/DAEDRA.ipynb
ADDED
@@ -0,0 +1,671 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"source": [
+"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
+"\n",
+"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
+],
+"metadata": {}
+},
+{
+"cell_type": "code",
+"source": [
+"%pip install accelerate -U"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stdout",
+"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
+}
+],
+"execution_count": 1,
+"metadata": {
+"gather": {
+"logged": 1706475754655
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+},
+"tags": []
+}
+},
+{
+"cell_type": "code",
+"source": [
+"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stdout",
+"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
+}
+],
+"execution_count": 2,
+"metadata": {
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"import pandas as pd\n",
+"import numpy as np\n",
+"import torch\n",
+"import os\n",
+"from typing import List, Union\n",
+"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
+"from datasets import load_dataset, Dataset, DatasetDict\n",
+"import shap\n",
+"import wandb\n",
+"import evaluate\n",
+"import logging\n",
+"\n",
+"wandb.finish()\n",
+"\n",
+"\n",
+"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
+"\n",
+"%load_ext watermark"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stderr",
+"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 17:46:15.020290: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 17:46:16.031641: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031779: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
+}
+],
+"execution_count": 3,
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"node_id": "caZjjFP0OyQNMVgZDiwswE",
+"report_properties": {
+"rowId": "un8W7ez7ZwoGb5Co6nydEV"
+},
+"type": "CODE"
+},
+"gather": {
+"logged": 1706550378660
+},
+"tags": []
+}
+},
+{
+"cell_type": "code",
+"source": [
+"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
+"\n",
+"SEED: int = 42\n",
+"\n",
+"BATCH_SIZE: int = 32\n",
+"EPOCHS: int = 5\n",
+"model_ckpt: str = \"distilbert-base-uncased\"\n",
+"\n",
+"# WandB configuration\n",
+"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
+"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
+"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
+],
+"outputs": [],
+"execution_count": 4,
+"metadata": {
+"collapsed": false,
+"gather": {
+"logged": 1706550378812
+},
+"jupyter": {
+"outputs_hidden": false
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"%watermark --iversion"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stdout",
+"text": "shap : 0.44.1\npandas : 2.0.2\nwandb : 0.16.2\nre : 2.2.1\nevaluate: 0.4.1\ntorch : 1.12.0\nnumpy : 1.23.5\nlogging : 0.5.1.2\n\n"
+}
+],
+"execution_count": 5,
+"metadata": {
+"collapsed": false,
+"jupyter": {
+"outputs_hidden": false
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"!nvidia-smi"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stdout",
+"text": "Mon Jan 29 17:46:18 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
+}
+],
+"execution_count": 6,
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": true,
+"hide_output_from_viewers": true,
+"node_id": "UU2oOJhwbIualogG1YyCMd",
+"type": "CODE"
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Loading the data set"
+],
+"metadata": {
+"datalore": {
+"hide_input_from_viewers": false,
+"hide_output_from_viewers": false,
+"node_id": "t45KHugmcPVaO0nuk8tGJ9",
+"report_properties": {
+"rowId": "40nN9Hvgi1clHNV5RAemI5"
+},
+"type": "MD"
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
+],
+"outputs": [],
+"execution_count": 7,
+"metadata": {
+"collapsed": false,
+"gather": {
+"logged": 1706550381141
+},
+"jupyter": {
+"outputs_hidden": false
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"dataset"
+],
+"outputs": [
+{
+"output_type": "execute_result",
+"execution_count": 8,
+"data": {
+"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
+},
+"metadata": {}
+}
+],
+"execution_count": 8,
+"metadata": {
+"collapsed": false,
+"gather": {
+"logged": 1706550381303
+},
+"jupyter": {
+"outputs_hidden": false,
+"source_hidden": false
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"SUBSAMPLING = 0.01\n",
+"\n",
+"if SUBSAMPLING < 1:\n",
+"    _ = DatasetDict()\n",
+"    for each in dataset.keys():\n",
+"        _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
+"\n",
+"    dataset = _"
+],
+"outputs": [],
+"execution_count": 9,
+"metadata": {
+"gather": {
+"logged": 1706550381472
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Tokenisation and encoding"
+],
+"metadata": {}
+},
+{
+"cell_type": "code",
+"source": [
+"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
+"    tokenizer = AutoTokenizer.from_pretrained(tokenizer_model)\n",
+"    ds_enc = ds.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True)\n",
+"    return ds_enc"
+],
+"outputs": [],
+"execution_count": 10,
+"metadata": {
+"gather": {
+"logged": 1706550381637
+}
+}
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Evaluation metrics"
+],
+"metadata": {}
+},
+{
+"cell_type": "code",
+"source": [
+"accuracy = evaluate.load(\"accuracy\")\n",
+"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
+"f1 = evaluate.load(\"f1\")"
+],
+"outputs": [],
+"execution_count": 11,
+"metadata": {
+"gather": {
+"logged": 1706550381778
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"def compute_metrics(eval_pred):\n",
+"    predictions, labels = eval_pred\n",
+"    predictions = np.argmax(predictions, axis=1)\n",
+"    return {\n",
+"        'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
+"        'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
+"        'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
+"        'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
+"        'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
+"        'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
+"    }"
+],
+"outputs": [],
+"execution_count": 12,
+"metadata": {
+"gather": {
+"logged": 1706550381891
+}
+}
+},
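A quick way to sanity-check compute_metrics outside the Trainer loop is to call it on hand-made logits. A minimal sketch; the three-class logits and labels below are invented for illustration and assume the evaluate metrics loaded above:

import numpy as np

# Invented logits: the argmax of each row is 0, 1 and 2 respectively.
dummy_logits = np.array([[2.0, 0.1, -1.0],
                         [0.2, 1.5, 0.3],
                         [0.1, 0.2, 3.0]])
dummy_labels = np.array([0, 1, 2])

# Every prediction matches its label, so all metrics should come out as 1.0.
print(compute_metrics((dummy_logits, dummy_labels)))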
+{
+"cell_type": "markdown",
+"source": [
+"## Training"
+],
+"metadata": {}
+},
+{
+"cell_type": "markdown",
+"source": [
+"We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
+],
+"metadata": {}
+},
+{
+"cell_type": "code",
+"source": [
+"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
+],
+"outputs": [],
+"execution_count": 13,
+"metadata": {
+"gather": {
+"logged": 1706550382032
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"def train_from_model(model_ckpt: str, push: bool = False):\n",
+"    print(f\"Initialising training based on {model_ckpt}...\")\n",
+"\n",
+"    print(\"Tokenising...\")\n",
+"    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
+"\n",
+"    cols = dataset[\"train\"].column_names\n",
+"    cols.remove(\"label\")\n",
+"    ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
+"\n",
+"    print(\"Loading model...\")\n",
+"    try:\n",
+"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
+"            num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
+"            id2label=label_map, \n",
+"            label2id={v:k for k,v in label_map.items()})\n",
+"    except OSError:\n",
+"        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
+"            num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
+"            id2label=label_map, \n",
+"            label2id={v:k for k,v in label_map.items()},\n",
+"            from_tf=True)\n",
+"\n",
+"\n",
+"    args = TrainingArguments(\n",
+"        output_dir=\"vaers\",\n",
+"        evaluation_strategy=\"epoch\",\n",
+"        save_strategy=\"epoch\",\n",
+"        learning_rate=2e-5,\n",
+"        per_device_train_batch_size=BATCH_SIZE,\n",
+"        per_device_eval_batch_size=BATCH_SIZE,\n",
+"        num_train_epochs=EPOCHS,\n",
+"        weight_decay=.01,\n",
+"        logging_steps=1,\n",
+"        load_best_model_at_end=True,\n",
+"        run_name=f\"daedra-training\",\n",
+"        report_to=[\"wandb\"])\n",
+"\n",
+"    trainer = Trainer(\n",
+"        model=model,\n",
+"        args=args,\n",
+"        train_dataset=ds_enc[\"train\"],\n",
+"        eval_dataset=ds_enc[\"test\"],\n",
+"        tokenizer=tokenizer,\n",
+"        compute_metrics=compute_metrics)\n",
+"    \n",
+"    if SUBSAMPLING != 1.0:\n",
+"        wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
+"    else:\n",
+"        wandb_tag: List[str] = [f\"full_sample\"]\n",
+"\n",
+"    wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
+"    wandb_tag.append(f\"base:{model_ckpt}\")\n",
+"    \n",
+"    wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
+"\n",
+"    print(\"Starting training...\")\n",
+"\n",
+"    trainer.train()\n",
+"\n",
+"    print(\"Training finished.\")\n",
+"\n",
+"    if push:\n",
+"        variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
+"        tokenizer._tokenizer.save(\"tokenizer.json\")\n",
+"        tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
+"        sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
+"\n",
+"        model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
+"            variant=variant,\n",
+"            commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
+],
+"outputs": [],
+"execution_count": 14,
+"metadata": {
+"jupyter": {
+"outputs_hidden": false,
+"source_hidden": false
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+},
+"gather": {
+"logged": 1706550382160
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"\n",
+"base_models = [\n",
+"    \"bert-base-uncased\",\n",
+"    \"distilbert-base-uncased\",\n",
+"]"
+],
+"outputs": [],
+"execution_count": 15,
+"metadata": {
+"gather": {
+"logged": 1706550382318
+}
+}
+},
+{
+"cell_type": "code",
+"source": [
+"BATCH_SIZE=1\n",
+"\n",
+"train_from_model(\"biobert/Bio_ClinicalBERT/\")"
+],
+"outputs": [
+{
+"output_type": "stream",
+"name": "stdout",
+"text": "Initialising training based on biobert/Bio_ClinicalBERT/...\nTokenising...\nLoading model...\n"
+},
+{
+"output_type": "stream",
+"name": "stderr",
+"text": "Map: 100%|██████████| 2722/2722 [00:01<00:00, 2195.12 examples/s]\nAll TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nAll the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Finishing last run (ID:sg022tqh) before initializing another..."
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View run <strong style=\"color:#cdcd00\">daedra_0.01-biobert/Bio_ClinicalBERT/</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Find logs at: <code>./wandb/run-20240129_174816-sg022tqh/logs</code>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Successfully finished last run (ID:sg022tqh). Initializing new run:<br/>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Tracking run with wandb version 0.16.2"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_174936-kilkkg1j</code>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">daedra_0.01-biobert/Bio_ClinicalBERT/</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
+},
+"metadata": {}
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j</a>"
+},
+"metadata": {}
+},
+{
+"output_type": "stream",
+"name": "stdout",
+"text": "Starting training...\n"
+},
+{
+"output_type": "stream",
+"name": "stderr",
+"text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
+},
+{
+"output_type": "display_data",
+"data": {
+"text/plain": "<IPython.core.display.HTML object>",
+"text/html": "\n <div>\n \n <progress value='1496' max='15880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1496/15880 07:43 < 1:14:19, 3.23 it/s, Epoch 0.47/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
+},
+"metadata": {}
+}
+],
+"execution_count": 21,
+"metadata": {
+"gather": {
+"logged": 1706551053473
+}
+}
+},
+{
+"cell_type": "code",
+"source": [],
+"outputs": [],
+"execution_count": null,
+"metadata": {
+"jupyter": {
+"source_hidden": false,
+"outputs_hidden": false
+},
+"nteract": {
+"transient": {
+"deleting": false
+}
+}
+}
+}
+],
+"metadata": {
+"datalore": {
+"base_environment": "default",
+"computation_mode": "JUPYTER",
+"package_manager": "pip",
+"packages": [
+{
+"name": "datasets",
+"source": "PIP",
+"version": "2.16.1"
+},
+{
+"name": "torch",
+"source": "PIP",
+"version": "2.1.2"
+},
+{
+"name": "accelerate",
+"source": "PIP",
+"version": "0.26.1"
+}
+],
+"report_row_ids": [
+"un8W7ez7ZwoGb5Co6nydEV",
+"40nN9Hvgi1clHNV5RAemI5",
+"TgRD90H5NSPpKS41OeXI1w",
+"ZOm5BfUs3h1EGLaUkBGeEB",
+"kOP0CZWNSk6vqE3wkPp7Vc",
+"W4PWcOu2O2pRaZyoE2W80h",
+"RolbOnQLIftk0vy9mIcz5M",
+"8OPhUgbaNJmOdiq5D3a6vK",
+"5Qrt3jSvSrpK6Ne1hS6shL",
+"hTq7nFUrovN5Ao4u6dIYWZ",
+"I8WNZLpJ1DVP2wiCW7YBIB",
+"SawhU3I9BewSE1XBPstpNJ",
+"80EtLEl2FIE4FqbWnUD3nT"
+],
+"version": 3
+},
+"kernel_info": {
+"name": "python38-azureml-pt-tf"
+},
+"kernelspec": {
+"display_name": "azureml_py38_PT_TF",
+"language": "python",
+"name": "python3"
+},
+"language_info": {
+"name": "python",
+"version": "3.8.5",
+"mimetype": "text/x-python",
+"codemirror_mode": {
+"name": "ipython",
+"version": 3
+},
+"pygments_lexer": "ipython3",
+"nbconvert_exporter": "python",
+"file_extension": ".py"
+},
+"microsoft": {
+"host": {
+"AzureML": {
+"notebookHasBeenCompleted": true
+}
+},
+"ms_spell_check": {
+"ms_spell_check_language": "en"
+}
+},
+"nteract": {
+"version": "nteract-front-end@1.0.0"
+}
+},
+"nbformat": 4,
+"nbformat_minor": 4
+}
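Once train_from_model() has pushed a checkpoint, the model can be pulled straight back from the Hub for inference. A minimal sketch, assuming the chrisvoncsefalvay/daedra repository used in the push above is accessible; the report text is invented for illustration:

from transformers import pipeline

# Load the pushed DAEDRA checkpoint by its Hub repo id (from train_from_model above).
classifier = pipeline("text-classification", model="chrisvoncsefalvay/daedra")

# Invented example report text, for illustration only.
print(classifier("Patient developed a mild fever and localised swelling after vaccination."))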
notebooks/DAEDRA.yml
ADDED
File without changes
notebooks/Dataset preparation.ipynb
ADDED
@@ -0,0 +1,524 @@
+{
+"cells": [
+{
+"cell_type": "markdown",
+"source": [
+"# Dataset processing\n",
+"\n",
+"This notebook processes the raw csv outputs from VAERS into Huggingface datasets. It shouldn't generally need to be run by the end user. "
+],
+"metadata": {
+"collapsed": false
+},
+"id": "35523bbeb2e03eae"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"import pandas as pd\n",
+"import datasets\n",
+"import glob\n",
+"import tqdm.notebook as tqdm\n",
+"from sklearn.model_selection import train_test_split\n",
+"from typing import Tuple\n",
+"from datetime import datetime\n",
+"\n",
+"pd.set_option('future.no_silent_downcasting', True)"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:38.481853Z",
+"start_time": "2024-01-27T22:28:38.458294Z"
+}
+},
+"id": "9362802d64424442",
+"execution_count": 15
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"HF_URL: str = \"chrisvoncsefalvay/vaers-outcomes\"\n",
+"\n",
+"FLAG_COLUMNS: list = [\"DIED\", \"ER_VISIT\", \"HOSPITAL\", \"OFC_VISIT\", \"X_STAY\", \"DISABLE\"]\n",
+"DEMOGRAPHIC_COLUMNS: list = [\"AGE_YRS\", \"SEX\"]\n",
+"DERIVED_COLUMNS: list = [\"D_PRESENTED\"]\n",
+"ID_COLUMNS: list = [\"VAERS_ID\"]\n",
+"TEXT_COLUMNS: list = [\"SYMPTOM_TEXT\"]\n",
+"\n",
+"TEST_TRAIN_FRACTION: float = 0.3\n",
+"TRAIN_VAL_FRACTION: float = 0.5"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:38.498974Z",
+"start_time": "2024-01-27T22:28:38.486237Z"
+}
+},
+"id": "34b77edf5a1fce96",
+"execution_count": 16
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Reading data files"
+],
+"metadata": {
+"collapsed": false
+},
+"id": "f5f84ddd06e9313e"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"def read_aggregate(pattern: str) -> pd.DataFrame:\n",
+"    files = glob.glob(f\"../data/{pattern}\")\n",
+"    dfs = []\n",
+"    for file in tqdm.tqdm(files):\n",
+"        dfs.append(pd.read_csv(file, encoding=\"latin-1\", low_memory=False))\n",
+"\n",
+"    res = pd.concat(dfs, ignore_index=True)\n",
+"    \n",
+"    print(f\"Processed {len(dfs)} files for a total of {len(res)} records.\")\n",
+"    \n",
+"    return res"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:38.508227Z",
+"start_time": "2024-01-27T22:28:38.500697Z"
+}
+},
+"id": "a7772ed4b4b51868",
+"execution_count": 17
+},
+{
+"cell_type": "code",
+"outputs": [
+{
+"data": {
+"text/plain": " 0%| | 0/1 [00:00<?, ?it/s]",
+"application/vnd.jupyter.widget-view+json": {
+"version_major": 2,
+"version_minor": 0,
+"model_id": "8a6919ed3c7e4c3a8885bb0991e856c7"
+}
+},
+"metadata": {},
+"output_type": "display_data"
+},
+{
+"name": "stdout",
+"output_type": "stream",
+"text": [
+"Processed 1 files for a total of 105726 records.\n"
+]
+}
+],
+"source": [
+"data = read_aggregate(\"*VAERSDATA.csv\")"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.567031Z",
+"start_time": "2024-01-27T22:28:38.510939Z"
+}
+},
+"id": "795e389489cbc6cf",
+"execution_count": 18
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"_keep: list = ID_COLUMNS + DEMOGRAPHIC_COLUMNS + TEXT_COLUMNS + FLAG_COLUMNS + [\"ER_ED_VISIT\"]\n",
+"data = data[_keep]"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.603326Z",
+"start_time": "2024-01-27T22:28:39.569131Z"
+}
+},
+"id": "5297fca83e18b502",
+"execution_count": 19
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Recoding\n",
+"\n",
+"We recode as follows:\n",
+"\n",
+"* For the outcome flags, `NaN` is recoded as `0` and `Y` is recoded as `1`.\n",
+"* `ER_VISIT` and `ER_ED_VISIT` are coalesced into a single column called `ER_VISIT` that is `1`-valued if either is `1`-valued, otherwise it is `0`-valued. This is to manage the renaming of the column in the VAERS data.\n",
+"* `NaN`s in the symptom text will drop the record."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "9467a8081810458e"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"def recode(df: pd.DataFrame) -> pd.DataFrame:\n",
+"    for column in FLAG_COLUMNS + [\"ER_ED_VISIT\"]:\n",
+"        df[column] = df[column].replace(\"Y\", 1).fillna(0).astype(int)\n",
+"    \n",
+"    df['ER_VISIT'] = df[['ER_VISIT', 'ER_ED_VISIT']].max(axis=1)\n",
+"    \n",
+"    df = df.drop(columns=['ER_ED_VISIT'])\n",
+"    \n",
+"    df = df.dropna(subset=['SYMPTOM_TEXT'])\n",
+"    \n",
+"    return df"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.603731Z",
+"start_time": "2024-01-27T22:28:39.590617Z"
+}
+},
+"id": "9aad00c9fe40adb8",
+"execution_count": 20
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.604024Z",
+"start_time": "2024-01-27T22:28:39.593891Z"
+}
+},
+"id": "b0fdcab6ee807404",
+"execution_count": 20
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"data = recode(data)"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.665777Z",
+"start_time": "2024-01-27T22:28:39.597946Z"
+}
+},
+"id": "f23ee0eae1b70387",
+"execution_count": 21
+},
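For intuition, recode() can be exercised on a toy frame. A minimal sketch; the two invented records below cover every column the function touches:

import numpy as np
import pandas as pd

# Invented two-record frame: row 0 has a hospitalisation flag and the legacy
# ER_ED_VISIT flag; row 1 has no symptom text and should be dropped.
toy = pd.DataFrame({
    "SYMPTOM_TEXT": ["fever", np.nan],
    "DIED": [np.nan, "Y"],
    "ER_VISIT": [np.nan, np.nan],
    "HOSPITAL": ["Y", np.nan],
    "OFC_VISIT": [np.nan, np.nan],
    "X_STAY": [np.nan, np.nan],
    "DISABLE": [np.nan, np.nan],
    "ER_ED_VISIT": ["Y", np.nan],
})

# Expect a single remaining row with ER_VISIT coalesced to 1 via ER_ED_VISIT,
# HOSPITAL == 1, and all other flags 0.
print(recode(toy))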
+{
+"cell_type": "markdown",
+"source": [
+"## Derived fields\n",
+"\n",
+"We create the derived field `D_PRESENTED`. This is to provide a shorthand for patients who present in any way: ER, hospitalisation, office visit. It also comprises patients whose hospital stay is extended (`X_STAY`) as this is typically the consequence of presenting."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "1c2f6b4fc2ae630b"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"data['D_PRESENTED'] = data[['ER_VISIT', 'HOSPITAL', 'OFC_VISIT', 'X_STAY']].max(axis=1)"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.679534Z",
+"start_time": "2024-01-27T22:28:39.667363Z"
+}
+},
+"id": "678847c70756695e",
+"execution_count": 22
+},
+{
+"cell_type": "markdown",
+"source": [
+"## Test/train/validate split\n",
+"\n",
+"We do a stratified split by age quintile and gender into test, train and validate sets."
+],
+"metadata": {
+"collapsed": false
+},
+"id": "dae902b111c8ef3c"
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"def stratified_split(df: pd.DataFrame, test_train_fraction: float, train_val_fraction: float, random_state: int = None) -> Tuple[pd.DataFrame, pd.DataFrame, pd.DataFrame]:\n",
+"    df['AGE_QUINTILE'] = pd.qcut(df['AGE_YRS'], 5, labels = False)\n",
+"    df['STRATIFICATION_VARIABLE'] = df['SEX'].astype(str) + \"_\" + df['AGE_QUINTILE'].astype(str)\n",
+"    df = df.drop(columns=['AGE_QUINTILE'])\n",
+"    \n",
+"    _, train = train_test_split(df, train_size=test_train_fraction, random_state=random_state, stratify=df.STRATIFICATION_VARIABLE)\n",
+"    \n",
+"    val, test = train_test_split(_, train_size=train_val_fraction, random_state=random_state, stratify=_.STRATIFICATION_VARIABLE)\n",
+"    \n",
+"    train = train.drop(columns=\"STRATIFICATION_VARIABLE\")\n",
+"    val = val.drop(columns=\"STRATIFICATION_VARIABLE\")\n",
+"    test = test.drop(columns=\"STRATIFICATION_VARIABLE\") \n",
+"    \n",
+"    return train, test, val"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.680497Z",
+"start_time": "2024-01-27T22:28:39.678055Z"
+}
+},
+"id": "ddee47653c94ff02",
+"execution_count": 23
+},
+{
+"cell_type": "code",
+"outputs": [],
+"source": [
+"train, test, val = stratified_split(data, TEST_TRAIN_FRACTION, TRAIN_VAL_FRACTION)"
+],
+"metadata": {
+"collapsed": false,
+"ExecuteTime": {
+"end_time": "2024-01-27T22:28:39.863489Z",
+"start_time": "2024-01-27T22:28:39.680464Z"
+}
+},
+"id": "bb16aaad0127ef7d",
+"execution_count": 24
+},
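Note the slightly counter-intuitive unpacking in stratified_split(): with TEST_TRAIN_FRACTION = 0.3 passed as train_size, the first train_test_split() assigns 30% of the data to `_`, which is then halved into val and test, leaving 70% for training. A quick check (a sketch; exact counts depend on the data):

# Roughly 70/15/15 is expected for train/test/val.
n = len(train) + len(test) + len(val)
print(f"train: {len(train)/n:.0%}, test: {len(test)/n:.0%}, val: {len(val)/n:.0%}")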
|
310 |
+
{
|
311 |
+
"cell_type": "markdown",
|
312 |
+
"source": [
|
313 |
+
"## Converting to labels"
|
314 |
+
],
|
315 |
+
"metadata": {
|
316 |
+
"collapsed": false
|
317 |
+
},
|
318 |
+
"id": "d61bfdc4a2879905"
|
319 |
+
},
|
320 |
+
{
|
321 |
+
"cell_type": "code",
|
322 |
+
"outputs": [],
|
323 |
+
"source": [
|
324 |
+
"def convert_to_dataset(df: pd.DataFrame) -> datasets.Dataset:\n",
|
325 |
+
" df = df.loc[:, ID_COLUMNS + TEXT_COLUMNS + FLAG_COLUMNS + DERIVED_COLUMNS]\n",
|
326 |
+
" \n",
|
327 |
+
" # We create the labels – these have to be floats for multilabel classification that uses BCEWithLogitsLoss\n",
|
328 |
+
" df.loc[:, \"labels\"] = df[FLAG_COLUMNS + DERIVED_COLUMNS].values.astype(float).tolist()\n",
|
329 |
+
" \n",
|
330 |
+
" print(f\"Building dataset with the following label order: {' '.join(FLAG_COLUMNS + DERIVED_COLUMNS)}\")\n",
|
331 |
+
" \n",
|
332 |
+
" # We drop the flag columns\n",
|
333 |
+
" df = df.drop(columns=FLAG_COLUMNS).drop(columns=DERIVED_COLUMNS)\n",
|
334 |
+
" \n",
|
335 |
+
" # We rename the remaining columns\n",
|
336 |
+
" df = df.rename(columns={\"SYMPTOM_TEXT\": \"text\", \"VAERS_ID\": \"id\"})\n",
|
337 |
+
" \n",
|
338 |
+
" return datasets.Dataset.from_pandas(df, preserve_index=False)"
|
339 |
+
],
|
340 |
+
"metadata": {
|
341 |
+
"collapsed": false,
|
342 |
+
"ExecuteTime": {
|
343 |
+
"end_time": "2024-01-27T22:28:39.867392Z",
|
344 |
+
"start_time": "2024-01-27T22:28:39.864829Z"
|
345 |
+
}
|
346 |
+
},
|
347 |
+
"id": "3d602444d33b7130",
|
348 |
+
"execution_count": 25
|
349 |
+
},
|
350 |
+
{
|
351 |
+
"cell_type": "code",
|
352 |
+
"outputs": [
|
353 |
+
{
|
354 |
+
"name": "stdout",
|
355 |
+
"output_type": "stream",
|
356 |
+
"text": [
|
357 |
+
"Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n",
|
358 |
+
"Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n",
|
359 |
+
"Building dataset with the following label order: DIED ER_VISIT HOSPITAL OFC_VISIT X_STAY DISABLE D_PRESENTED\n"
|
360 |
+
]
|
361 |
+
}
|
362 |
+
],
|
363 |
+
"source": [
|
364 |
+
"ds = datasets.DatasetDict()\n",
|
365 |
+
"ds[\"train\"] = convert_to_dataset(train)\n",
|
366 |
+
"ds[\"test\"] = convert_to_dataset(test)\n",
|
367 |
+
"ds[\"val\"] = convert_to_dataset(val)"
|
368 |
+
],
|
369 |
+
"metadata": {
|
370 |
+
"collapsed": false,
|
371 |
+
"ExecuteTime": {
|
372 |
+
"end_time": "2024-01-27T22:28:40.207548Z",
|
373 |
+
"start_time": "2024-01-27T22:28:39.872665Z"
|
374 |
+
}
|
375 |
+
},
|
376 |
+
"id": "e7c854a072956ca3",
|
377 |
+
"execution_count": 26
|
378 |
+
},
|
379 |
+
{
|
380 |
+
"cell_type": "markdown",
|
381 |
+
"source": [
|
382 |
+
"## Saving to Huggingface Hub"
|
383 |
+
],
|
384 |
+
"metadata": {
|
385 |
+
"collapsed": false
|
386 |
+
},
|
387 |
+
"id": "ec0167c068238f5a"
|
388 |
+
},
|
389 |
+
{
|
390 |
+
"cell_type": "code",
|
391 |
+
"outputs": [
|
392 |
+
{
|
393 |
+
"data": {
|
394 |
+
"text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]",
|
395 |
+
"application/vnd.jupyter.widget-view+json": {
|
396 |
+
"version_major": 2,
|
397 |
+
"version_minor": 0,
|
398 |
+
"model_id": "c196be983bbc474186dad4b75347aebb"
|
399 |
+
}
|
400 |
+
},
|
401 |
+
"metadata": {},
|
402 |
+
"output_type": "display_data"
|
403 |
+
},
|
404 |
+
{
|
405 |
+
"data": {
|
406 |
+
"text/plain": "Creating parquet from Arrow format: 0%| | 0/74 [00:00<?, ?ba/s]",
|
407 |
+
"application/vnd.jupyter.widget-view+json": {
|
408 |
+
"version_major": 2,
|
409 |
+
"version_minor": 0,
|
410 |
+
"model_id": "9bb3cbdfa4e84b96a68929fc3326536d"
|
411 |
+
}
|
412 |
+
},
|
413 |
+
"metadata": {},
|
414 |
+
"output_type": "display_data"
|
415 |
+
},
|
416 |
+
{
|
417 |
+
"data": {
|
418 |
+
"text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]",
|
419 |
+
"application/vnd.jupyter.widget-view+json": {
|
420 |
+
"version_major": 2,
|
421 |
+
"version_minor": 0,
|
422 |
+
"model_id": "66aa46f327264d7aa8f42f4a1bcf0775"
|
423 |
+
}
|
424 |
+
},
|
425 |
+
"metadata": {},
|
426 |
+
"output_type": "display_data"
|
427 |
+
},
|
428 |
+
{
|
429 |
+
"data": {
|
430 |
+
"text/plain": "Creating parquet from Arrow format: 0%| | 0/16 [00:00<?, ?ba/s]",
|
431 |
+
"application/vnd.jupyter.widget-view+json": {
|
432 |
+
"version_major": 2,
|
433 |
+
"version_minor": 0,
|
434 |
+
"model_id": "b14c57836adc4a3692f9594acc164ff0"
|
435 |
+
}
|
436 |
+
},
|
437 |
+
"metadata": {},
|
438 |
+
"output_type": "display_data"
|
439 |
+
},
|
440 |
+
{
|
441 |
+
"data": {
|
442 |
+
"text/plain": "Uploading the dataset shards: 0%| | 0/1 [00:00<?, ?it/s]",
|
443 |
+
"application/vnd.jupyter.widget-view+json": {
|
444 |
+
"version_major": 2,
|
445 |
+
"version_minor": 0,
|
446 |
+
"model_id": "d395ca6f2c9b4ee5bb49dbce3a9bd064"
|
447 |
+
}
|
448 |
+
},
|
449 |
+
"metadata": {},
|
450 |
+
"output_type": "display_data"
|
451 |
+
},
|
452 |
+
{
|
453 |
+
"data": {
|
454 |
+
"text/plain": "Creating parquet from Arrow format: 0%| | 0/16 [00:00<?, ?ba/s]",
|
455 |
+
"application/vnd.jupyter.widget-view+json": {
|
456 |
+
"version_major": 2,
|
457 |
+
"version_minor": 0,
|
458 |
+
"model_id": "71780ee50ab649338bfa217f1767cca7"
|
459 |
+
}
|
460 |
+
},
|
461 |
+
"metadata": {},
|
462 |
+
"output_type": "display_data"
|
463 |
+
},
|
464 |
+
{
|
465 |
+
"data": {
|
466 |
+
"text/plain": "README.md: 0%| | 0.00/94.0 [00:00<?, ?B/s]",
|
467 |
+
"application/vnd.jupyter.widget-view+json": {
|
468 |
+
"version_major": 2,
|
469 |
+
"version_minor": 0,
|
470 |
+
"model_id": "1983f75eccf044649ab6423cad68dfdc"
|
471 |
+
}
|
472 |
+
},
|
473 |
+
"metadata": {},
|
474 |
+
"output_type": "display_data"
|
475 |
+
},
|
476 |
+
{
|
477 |
+
"data": {
|
478 |
+
"text/plain": "CommitInfo(commit_url='https://huggingface.co/datasets/chrisvoncsefalvay/vaers-outcomes/commit/65fa5129a0b1eb64f8fdd1aca5490965810e4ddb', commit_message='Data set commit of 105238 records of VAERS data at 2024-01-27T15:28:40.206686.', commit_description='', oid='65fa5129a0b1eb64f8fdd1aca5490965810e4ddb', pr_url='https://huggingface.co/datasets/chrisvoncsefalvay/vaers-outcomes/discussions/1', pr_revision='refs/pr/1', pr_num=1)"
|
479 |
+
},
|
480 |
+
"execution_count": 27,
|
481 |
+
"metadata": {},
|
482 |
+
"output_type": "execute_result"
|
483 |
+
}
|
484 |
+
],
|
485 |
+
"source": [
|
486 |
+
"commit_message = f\"\"\"Data set commit of {len(train) + len(test) + len(val)} records of VAERS data at {datetime.now().isoformat()}.\"\"\"\n",
|
487 |
+
"\n",
|
488 |
+
"ds.push_to_hub(HF_URL, \n",
|
489 |
+
" commit_message=commit_message,\n",
|
490 |
+
" create_pr=True)"
|
491 |
+
],
|
492 |
+
"metadata": {
|
493 |
+
"collapsed": false,
|
494 |
+
"ExecuteTime": {
|
495 |
+
"end_time": "2024-01-27T22:28:45.264233Z",
|
496 |
+
"start_time": "2024-01-27T22:28:40.207690Z"
|
497 |
+
}
|
498 |
+
},
|
499 |
+
"id": "104ffca720a27624",
|
500 |
+
"execution_count": 27
|
501 |
+
}
|
502 |
+
],
|
503 |
+
"metadata": {
|
504 |
+
"kernelspec": {
|
505 |
+
"display_name": "Python 3",
|
506 |
+
"language": "python",
|
507 |
+
"name": "python3"
|
508 |
+
},
|
509 |
+
"language_info": {
|
510 |
+
"codemirror_mode": {
|
511 |
+
"name": "ipython",
|
512 |
+
"version": 2
|
513 |
+
},
|
514 |
+
"file_extension": ".py",
|
515 |
+
"mimetype": "text/x-python",
|
516 |
+
"name": "python",
|
517 |
+
"nbconvert_exporter": "python",
|
518 |
+
"pygments_lexer": "ipython2",
|
519 |
+
"version": "2.7.6"
|
520 |
+
}
|
521 |
+
},
|
522 |
+
"nbformat": 4,
|
523 |
+
"nbformat_minor": 5
|
524 |
+
}
|
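For readers following the `Dataset preparation.ipynb` diff above: the two-stage stratified split it commits is easy to sanity-check in isolation. Below is a minimal sketch with a toy DataFrame; the column names match the notebook, but the fractions (0.3 and 0.5) and the toy data are illustrative assumptions, not the notebook's actual constants.

```python
# Sketch of the two-stage stratified split from Dataset preparation.ipynb.
# Toy data and fractions are assumptions for illustration only.
import pandas as pd
from sklearn.model_selection import train_test_split

df = pd.DataFrame({
    "VAERS_ID": range(1000),
    "SYMPTOM_TEXT": ["report text"] * 1000,
    "STRATIFICATION_VARIABLE": [0] * 700 + [1] * 300,
})

# Stage 1: carve off a held-out pool (stratified), keep the rest for training.
holdout, train = train_test_split(
    df, train_size=0.3, random_state=42, stratify=df.STRATIFICATION_VARIABLE
)

# Stage 2: split the pool into validation and test, again stratified.
val, test = train_test_split(
    holdout, train_size=0.5, random_state=42, stratify=holdout.STRATIFICATION_VARIABLE
)

for name, part in {"train": train, "val": val, "test": test}.items():
    # Stratification preserves the 70/30 class ratio in every split.
    print(name, len(part), round(part.STRATIFICATION_VARIABLE.mean(), 2))
```

Note that, as in the notebook, the first return value of `train_test_split` (sized by `train_size`) is the held-out pool and the second is the training set.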
notebooks/Untitled.ipynb
ADDED
@@ -0,0 +1,33 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "code",
|
5 |
+
"execution_count": null,
|
6 |
+
"id": "8e9c5e9c-af14-4148-86bb-b04f18e4d13e",
|
7 |
+
"metadata": {},
|
8 |
+
"outputs": [],
|
9 |
+
"source": []
|
10 |
+
}
|
11 |
+
],
|
12 |
+
"metadata": {
|
13 |
+
"kernelspec": {
|
14 |
+
"display_name": "Python 3.8 - Pytorch and Tensorflow",
|
15 |
+
"language": "python",
|
16 |
+
"name": "python38-azureml-pt-tf"
|
17 |
+
},
|
18 |
+
"language_info": {
|
19 |
+
"codemirror_mode": {
|
20 |
+
"name": "ipython",
|
21 |
+
"version": 3
|
22 |
+
},
|
23 |
+
"file_extension": ".py",
|
24 |
+
"mimetype": "text/x-python",
|
25 |
+
"name": "python",
|
26 |
+
"nbconvert_exporter": "python",
|
27 |
+
"pygments_lexer": "ipython3",
|
28 |
+
"version": "3.8.5"
|
29 |
+
}
|
30 |
+
},
|
31 |
+
"nbformat": 4,
|
32 |
+
"nbformat_minor": 5
|
33 |
+
}
|
notebooks/comparisons.csv
ADDED
@@ -0,0 +1,3 @@
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:f915ebb630ffd80319041ec728d9c7123b821d8f96b4745e909e937213832d21
|
3 |
+
size 11079466
|
notebooks/daedra.ipynb.amltmp
ADDED
@@ -0,0 +1,671 @@
|
1 |
+
{
|
2 |
+
"cells": [
|
3 |
+
{
|
4 |
+
"cell_type": "markdown",
|
5 |
+
"source": [
|
6 |
+
"# DAEDRA: Determining Adverse Event Disposition for Regulatory Affairs\n",
|
7 |
+
"\n",
|
8 |
+
"DAEDRA is a language model intended to predict the disposition (outcome) of an adverse event based on the text of the event report. Intended to be used to classify reports in passive reporting systems, it is trained on the [VAERS](https://vaers.hhs.gov/) dataset, which contains reports of adverse events following vaccination in the United States."
|
9 |
+
],
|
10 |
+
"metadata": {}
|
11 |
+
},
|
12 |
+
{
|
13 |
+
"cell_type": "code",
|
14 |
+
"source": [
|
15 |
+
"%pip install accelerate -U"
|
16 |
+
],
|
17 |
+
"outputs": [
|
18 |
+
{
|
19 |
+
"output_type": "stream",
|
20 |
+
"name": "stdout",
|
21 |
+
"text": "Requirement already satisfied: accelerate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.26.1)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.23.5)\nRequirement already satisfied: huggingface-hub in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.20.3)\nRequirement already satisfied: pyyaml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (6.0)\nRequirement already satisfied: psutil in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (5.9.5)\nRequirement already satisfied: torch>=1.10.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (1.12.0)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from accelerate) (0.4.2)\nRequirement already satisfied: typing_extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from torch>=1.10.0->accelerate) (4.6.3)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2.31.0)\nRequirement already satisfied: tqdm>=4.42.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (4.65.0)\nRequirement already satisfied: fsspec>=2023.5.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (2023.10.0)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from huggingface-hub->accelerate) (3.13.1)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (2023.5.7)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (1.26.16)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.4)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->huggingface-hub->accelerate) (3.1.0)\nNote: you may need to restart the kernel to use updated packages.\n"
|
22 |
+
}
|
23 |
+
],
|
24 |
+
"execution_count": 1,
|
25 |
+
"metadata": {
|
26 |
+
"gather": {
|
27 |
+
"logged": 1706475754655
|
28 |
+
},
|
29 |
+
"nteract": {
|
30 |
+
"transient": {
|
31 |
+
"deleting": false
|
32 |
+
}
|
33 |
+
},
|
34 |
+
"tags": []
|
35 |
+
}
|
36 |
+
},
|
37 |
+
{
|
38 |
+
"cell_type": "code",
|
39 |
+
"source": [
|
40 |
+
"%pip install transformers datasets shap watermark wandb evaluate codecarbon"
|
41 |
+
],
|
42 |
+
"outputs": [
|
43 |
+
{
|
44 |
+
"output_type": "stream",
|
45 |
+
"name": "stdout",
|
46 |
+
"text": "Requirement already satisfied: transformers in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (4.37.1)\nRequirement already satisfied: datasets in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.16.1)\nRequirement already satisfied: shap in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.44.1)\nRequirement already satisfied: watermark in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.4.3)\nRequirement already satisfied: wandb in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.16.2)\nRequirement already satisfied: evaluate in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (0.4.1)\nRequirement already satisfied: codecarbon in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (2.3.3)\nRequirement already satisfied: regex!=2019.12.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2023.12.25)\nRequirement already satisfied: packaging>=20.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (23.1)\nRequirement already satisfied: safetensors>=0.3.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.4.2)\nRequirement already satisfied: pyyaml>=5.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (6.0)\nRequirement already satisfied: numpy>=1.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (1.23.5)\nRequirement already satisfied: filelock in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (3.13.1)\nRequirement already satisfied: requests in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (2.31.0)\nRequirement already satisfied: huggingface-hub<1.0,>=0.19.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.20.3)\nRequirement already satisfied: tqdm>=4.27 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (4.65.0)\nRequirement already satisfied: tokenizers<0.19,>=0.14 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from transformers) (0.15.1)\nRequirement already satisfied: pyarrow>=8.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (9.0.0)\nRequirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2023.10.0)\nRequirement already satisfied: pyarrow-hotfix in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.6)\nRequirement already satisfied: dill<0.3.8,>=0.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.3.7)\nRequirement already satisfied: aiohttp in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.9.1)\nRequirement already satisfied: multiprocess in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (0.70.15)\nRequirement already satisfied: xxhash in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (3.4.1)\nRequirement already satisfied: pandas in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from datasets) (2.0.2)\nRequirement already satisfied: scipy in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.10.1)\nRequirement already satisfied: slicer==0.0.7 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages 
(from shap) (0.0.7)\nRequirement already satisfied: cloudpickle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (2.2.1)\nRequirement already satisfied: scikit-learn in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (1.2.2)\nRequirement already satisfied: numba in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from shap) (0.58.1)\nRequirement already satisfied: setuptools in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (65.6.3)\nRequirement already satisfied: ipython>=6.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (8.12.2)\nRequirement already satisfied: importlib-metadata>=1.4 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from watermark) (6.7.0)\nRequirement already satisfied: Click!=8.0.0,>=7.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (8.1.3)\nRequirement already satisfied: typing-extensions in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (4.6.3)\nRequirement already satisfied: protobuf!=4.21.0,<5,>=3.12.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.19.6)\nRequirement already satisfied: setproctitle in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.3.3)\nRequirement already satisfied: psutil>=5.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (5.9.5)\nRequirement already satisfied: sentry-sdk>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.39.2)\nRequirement already satisfied: docker-pycreds>=0.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (0.4.0)\nRequirement already satisfied: appdirs>=1.4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (1.4.4)\nRequirement already satisfied: GitPython!=3.1.29,>=1.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from wandb) (3.1.31)\nRequirement already satisfied: responses<0.19 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from evaluate) (0.18.0)\nRequirement already satisfied: prometheus-client in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (0.19.0)\nRequirement already satisfied: py-cpuinfo in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (9.0.0)\nRequirement already satisfied: rapidfuzz in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (3.6.1)\nRequirement already satisfied: pynvml in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (11.5.0)\nRequirement already satisfied: arrow in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from codecarbon) (1.3.0)\nRequirement already satisfied: six>=1.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\nRequirement already satisfied: yarl<2.0,>=1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.9.4)\nRequirement already satisfied: async-timeout<5.0,>=4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (4.0.3)\nRequirement already satisfied: frozenlist>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.4.1)\nRequirement already satisfied: multidict<7.0,>=4.5 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (6.0.4)\nRequirement already satisfied: aiosignal>=1.1.2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (1.3.1)\nRequirement already satisfied: attrs>=17.3.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from aiohttp->datasets) (23.1.0)\nRequirement already satisfied: gitdb<5,>=4.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from GitPython!=3.1.29,>=1.0.0->wandb) (4.0.10)\nRequirement already satisfied: zipp>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from importlib-metadata>=1.4->watermark) (3.15.0)\nRequirement already satisfied: pickleshare in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.7.5)\nRequirement already satisfied: backcall in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.2.0)\nRequirement already satisfied: stack-data in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.6.2)\nRequirement already satisfied: prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (3.0.30)\nRequirement already satisfied: matplotlib-inline in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.1.6)\nRequirement already satisfied: decorator in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.1.1)\nRequirement already satisfied: traitlets>=5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (5.9.0)\nRequirement already satisfied: jedi>=0.16 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (0.18.2)\nRequirement already satisfied: pygments>=2.4.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (2.15.1)\nRequirement already satisfied: pexpect>4.3 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from ipython>=6.0->watermark) (4.8.0)\nRequirement already satisfied: urllib3<3,>=1.21.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (1.26.16)\nRequirement already satisfied: charset-normalizer<4,>=2 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.1.0)\nRequirement already satisfied: idna<4,>=2.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (3.4)\nRequirement already satisfied: certifi>=2017.4.17 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from requests->transformers) (2023.5.7)\nRequirement already satisfied: types-python-dateutil>=2.8.10 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.19.20240106)\nRequirement already satisfied: python-dateutil>=2.7.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from arrow->codecarbon) (2.8.2)\nRequirement already satisfied: llvmlite<0.42,>=0.41.0dev0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from numba->shap) (0.41.1)\nRequirement already satisfied: pytz>=2020.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: tzdata>=2022.1 in 
/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pandas->datasets) (2023.3)\nRequirement already satisfied: threadpoolctl>=2.0.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (3.1.0)\nRequirement already satisfied: joblib>=1.1.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from scikit-learn->shap) (1.2.0)\nRequirement already satisfied: smmap<6,>=3.0.1 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb) (5.0.0)\nRequirement already satisfied: parso<0.9.0,>=0.8.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from jedi>=0.16->ipython>=6.0->watermark) (0.8.3)\nRequirement already satisfied: ptyprocess>=0.5 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from pexpect>4.3->ipython>=6.0->watermark) (0.7.0)\nRequirement already satisfied: wcwidth in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from prompt-toolkit!=3.0.37,<3.1.0,>=3.0.30->ipython>=6.0->watermark) (0.2.6)\nRequirement already satisfied: asttokens>=2.1.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (2.2.1)\nRequirement already satisfied: pure-eval in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (0.2.2)\nRequirement already satisfied: executing>=1.2.0 in /anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages (from stack-data->ipython>=6.0->watermark) (1.2.0)\nNote: you may need to restart the kernel to use updated packages.\n"
|
47 |
+
}
|
48 |
+
],
|
49 |
+
"execution_count": 2,
|
50 |
+
"metadata": {
|
51 |
+
"nteract": {
|
52 |
+
"transient": {
|
53 |
+
"deleting": false
|
54 |
+
}
|
55 |
+
}
|
56 |
+
}
|
57 |
+
},
|
58 |
+
{
|
59 |
+
"cell_type": "code",
|
60 |
+
"source": [
|
61 |
+
"import pandas as pd\n",
|
62 |
+
"import numpy as np\n",
|
63 |
+
"import torch\n",
|
64 |
+
"import os\n",
|
65 |
+
"from typing import List, Union\n",
|
66 |
+
"from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel\n",
|
67 |
+
"from datasets import load_dataset, Dataset, DatasetDict\n",
|
68 |
+
"import shap\n",
|
69 |
+
"import wandb\n",
|
70 |
+
"import evaluate\n",
|
71 |
+
"import logging\n",
|
72 |
+
"\n",
|
73 |
+
"wandb.finish()\n",
|
74 |
+
"\n",
|
75 |
+
"\n",
|
76 |
+
"os.environ[\"TOKENIZERS_PARALLELISM\"] = \"false\"\n",
|
77 |
+
"\n",
|
78 |
+
"%load_ext watermark"
|
79 |
+
],
|
80 |
+
"outputs": [
|
81 |
+
{
|
82 |
+
"output_type": "stream",
|
83 |
+
"name": "stderr",
|
84 |
+
"text": "/anaconda/envs/azureml_py38_PT_TF/lib/python3.8/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n from .autonotebook import tqdm as notebook_tqdm\n2024-01-29 17:46:15.020290: I tensorflow/core/platform/cpu_feature_guard.cc:193] This TensorFlow binary is optimized with oneAPI Deep Neural Network Library (oneDNN) to use the following CPU instructions in performance-critical operations: AVX2 FMA\nTo enable them in other operations, rebuild TensorFlow with the appropriate compiler flags.\n2024-01-29 17:46:16.031641: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer.so.7'; dlerror: libnvinfer.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031779: W tensorflow/compiler/xla/stream_executor/platform/default/dso_loader.cc:64] Could not load dynamic library 'libnvinfer_plugin.so.7'; dlerror: libnvinfer_plugin.so.7: cannot open shared object file: No such file or directory\n2024-01-29 17:46:16.031793: W tensorflow/compiler/tf2tensorrt/utils/py_utils.cc:38] TF-TRT Warning: Cannot dlopen some TensorRT libraries. If you would like to use Nvidia GPU with TensorRT, please make sure the missing libraries mentioned above are installed properly.\n"
|
85 |
+
}
|
86 |
+
],
|
87 |
+
"execution_count": 3,
|
88 |
+
"metadata": {
|
89 |
+
"datalore": {
|
90 |
+
"hide_input_from_viewers": false,
|
91 |
+
"hide_output_from_viewers": false,
|
92 |
+
"node_id": "caZjjFP0OyQNMVgZDiwswE",
|
93 |
+
"report_properties": {
|
94 |
+
"rowId": "un8W7ez7ZwoGb5Co6nydEV"
|
95 |
+
},
|
96 |
+
"type": "CODE"
|
97 |
+
},
|
98 |
+
"gather": {
|
99 |
+
"logged": 1706550378660
|
100 |
+
},
|
101 |
+
"tags": []
|
102 |
+
}
|
103 |
+
},
|
104 |
+
{
|
105 |
+
"cell_type": "code",
|
106 |
+
"source": [
|
107 |
+
"device: str = 'cuda' if torch.cuda.is_available() else 'cpu'\n",
|
108 |
+
"\n",
|
109 |
+
"SEED: int = 42\n",
|
110 |
+
"\n",
|
111 |
+
"BATCH_SIZE: int = 32\n",
|
112 |
+
"EPOCHS: int = 5\n",
|
113 |
+
"model_ckpt: str = \"distilbert-base-uncased\"\n",
|
114 |
+
"\n",
|
115 |
+
"# WandB configuration\n",
|
116 |
+
"os.environ[\"WANDB_PROJECT\"] = \"DAEDRA multiclass model training\" \n",
|
117 |
+
"os.environ[\"WANDB_LOG_MODEL\"] = \"checkpoint\" # log all model checkpoints\n",
|
118 |
+
"os.environ[\"WANDB_NOTEBOOK_NAME\"] = \"DAEDRA.ipynb\""
|
119 |
+
],
|
120 |
+
"outputs": [],
|
121 |
+
"execution_count": 4,
|
122 |
+
"metadata": {
|
123 |
+
"collapsed": false,
|
124 |
+
"gather": {
|
125 |
+
"logged": 1706550378812
|
126 |
+
},
|
127 |
+
"jupyter": {
|
128 |
+
"outputs_hidden": false
|
129 |
+
}
|
130 |
+
}
|
131 |
+
},
|
132 |
+
{
|
133 |
+
"cell_type": "code",
|
134 |
+
"source": [
|
135 |
+
"%watermark --iversion"
|
136 |
+
],
|
137 |
+
"outputs": [
|
138 |
+
{
|
139 |
+
"output_type": "stream",
|
140 |
+
"name": "stdout",
|
141 |
+
"text": "shap : 0.44.1\npandas : 2.0.2\nwandb : 0.16.2\nre : 2.2.1\nevaluate: 0.4.1\ntorch : 1.12.0\nnumpy : 1.23.5\nlogging : 0.5.1.2\n\n"
|
142 |
+
}
|
143 |
+
],
|
144 |
+
"execution_count": 5,
|
145 |
+
"metadata": {
|
146 |
+
"collapsed": false,
|
147 |
+
"jupyter": {
|
148 |
+
"outputs_hidden": false
|
149 |
+
}
|
150 |
+
}
|
151 |
+
},
|
152 |
+
{
|
153 |
+
"cell_type": "code",
|
154 |
+
"source": [
|
155 |
+
"!nvidia-smi"
|
156 |
+
],
|
157 |
+
"outputs": [
|
158 |
+
{
|
159 |
+
"output_type": "stream",
|
160 |
+
"name": "stdout",
|
161 |
+
"text": "Mon Jan 29 17:46:18 2024 \r\n+---------------------------------------------------------------------------------------+\r\n| NVIDIA-SMI 535.129.03 Driver Version: 535.129.03 CUDA Version: 12.2 |\r\n|-----------------------------------------+----------------------+----------------------+\r\n| GPU Name Persistence-M | Bus-Id Disp.A | Volatile Uncorr. ECC |\r\n| Fan Temp Perf Pwr:Usage/Cap | Memory-Usage | GPU-Util Compute M. |\r\n| | | MIG M. |\r\n|=========================================+======================+======================|\r\n| 0 Tesla V100-PCIE-16GB Off | 00000001:00:00.0 Off | Off |\r\n| N/A 26C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 1 Tesla V100-PCIE-16GB Off | 00000002:00:00.0 Off | Off |\r\n| N/A 25C P0 23W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 2 Tesla V100-PCIE-16GB Off | 00000003:00:00.0 Off | Off |\r\n| N/A 25C P0 25W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n| 3 Tesla V100-PCIE-16GB Off | 00000004:00:00.0 Off | Off |\r\n| N/A 27C P0 24W / 250W | 4MiB / 16384MiB | 0% Default |\r\n| | | N/A |\r\n+-----------------------------------------+----------------------+----------------------+\r\n \r\n+---------------------------------------------------------------------------------------+\r\n| Processes: |\r\n| GPU GI CI PID Type Process name GPU Memory |\r\n| ID ID Usage |\r\n|=======================================================================================|\r\n| No running processes found |\r\n+---------------------------------------------------------------------------------------+\r\n"
|
162 |
+
}
|
163 |
+
],
|
164 |
+
"execution_count": 6,
|
165 |
+
"metadata": {
|
166 |
+
"datalore": {
|
167 |
+
"hide_input_from_viewers": true,
|
168 |
+
"hide_output_from_viewers": true,
|
169 |
+
"node_id": "UU2oOJhwbIualogG1YyCMd",
|
170 |
+
"type": "CODE"
|
171 |
+
}
|
172 |
+
}
|
173 |
+
},
|
174 |
+
{
|
175 |
+
"cell_type": "markdown",
|
176 |
+
"source": [
|
177 |
+
"## Loading the data set"
|
178 |
+
],
|
179 |
+
"metadata": {
|
180 |
+
"datalore": {
|
181 |
+
"hide_input_from_viewers": false,
|
182 |
+
"hide_output_from_viewers": false,
|
183 |
+
"node_id": "t45KHugmcPVaO0nuk8tGJ9",
|
184 |
+
"report_properties": {
|
185 |
+
"rowId": "40nN9Hvgi1clHNV5RAemI5"
|
186 |
+
},
|
187 |
+
"type": "MD"
|
188 |
+
}
|
189 |
+
}
|
190 |
+
},
|
191 |
+
{
|
192 |
+
"cell_type": "code",
|
193 |
+
"source": [
|
194 |
+
"dataset = load_dataset(\"chrisvoncsefalvay/vaers-outcomes\")"
|
195 |
+
],
|
196 |
+
"outputs": [],
|
197 |
+
"execution_count": 7,
|
198 |
+
"metadata": {
|
199 |
+
"collapsed": false,
|
200 |
+
"gather": {
|
201 |
+
"logged": 1706550381141
|
202 |
+
},
|
203 |
+
"jupyter": {
|
204 |
+
"outputs_hidden": false
|
205 |
+
}
|
206 |
+
}
|
207 |
+
},
|
208 |
+
{
|
209 |
+
"cell_type": "code",
|
210 |
+
"source": [
|
211 |
+
"dataset"
|
212 |
+
],
|
213 |
+
"outputs": [
|
214 |
+
{
|
215 |
+
"output_type": "execute_result",
|
216 |
+
"execution_count": 8,
|
217 |
+
"data": {
|
218 |
+
"text/plain": "DatasetDict({\n train: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 1270444\n })\n test: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n val: Dataset({\n features: ['id', 'text', 'label'],\n num_rows: 272238\n })\n})"
|
219 |
+
},
|
220 |
+
"metadata": {}
|
221 |
+
}
|
222 |
+
],
|
223 |
+
"execution_count": 8,
|
224 |
+
"metadata": {
|
225 |
+
"collapsed": false,
|
226 |
+
"gather": {
|
227 |
+
"logged": 1706550381303
|
228 |
+
},
|
229 |
+
"jupyter": {
|
230 |
+
"outputs_hidden": false,
|
231 |
+
"source_hidden": false
|
232 |
+
},
|
233 |
+
"nteract": {
|
234 |
+
"transient": {
|
235 |
+
"deleting": false
|
236 |
+
}
|
237 |
+
}
|
238 |
+
}
|
239 |
+
},
|
240 |
+
{
|
241 |
+
"cell_type": "code",
|
242 |
+
"source": [
|
243 |
+
"SUBSAMPLING = 0.01\n",
|
244 |
+
"\n",
|
245 |
+
"if SUBSAMPLING < 1:\n",
|
246 |
+
" _ = DatasetDict()\n",
|
247 |
+
" for each in dataset.keys():\n",
|
248 |
+
" _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))\n",
|
249 |
+
"\n",
|
250 |
+
" dataset = _"
|
251 |
+
],
|
252 |
+
"outputs": [],
|
253 |
+
"execution_count": 9,
|
254 |
+
"metadata": {
|
255 |
+
"gather": {
|
256 |
+
"logged": 1706550381472
|
257 |
+
}
|
258 |
+
}
|
259 |
+
},
|
260 |
+
{
|
261 |
+
"cell_type": "markdown",
|
262 |
+
"source": [
|
263 |
+
"## Tokenisation and encoding"
|
264 |
+
],
|
265 |
+
"metadata": {}
|
266 |
+
},
|
267 |
+
{
|
268 |
+
"cell_type": "code",
|
269 |
+
"source": [
|
270 |
+
"def encode_ds(ds: Union[Dataset, DatasetDict], tokenizer_model: str = model_ckpt) -> Union[Dataset, DatasetDict]:\n",
|
271 |
+
" return ds_enc"
|
272 |
+
],
|
273 |
+
"outputs": [],
|
274 |
+
"execution_count": 10,
|
275 |
+
"metadata": {
|
276 |
+
"gather": {
|
277 |
+
"logged": 1706550381637
|
278 |
+
}
|
279 |
+
}
|
280 |
+
},
|
281 |
+
{
|
282 |
+
"cell_type": "markdown",
|
283 |
+
"source": [
|
284 |
+
"## Evaluation metrics"
|
285 |
+
],
|
286 |
+
"metadata": {}
|
287 |
+
},
|
288 |
+
{
|
289 |
+
"cell_type": "code",
|
290 |
+
"source": [
|
291 |
+
"accuracy = evaluate.load(\"accuracy\")\n",
|
292 |
+
"precision, recall = evaluate.load(\"precision\"), evaluate.load(\"recall\")\n",
|
293 |
+
"f1 = evaluate.load(\"f1\")"
|
294 |
+
],
|
295 |
+
"outputs": [],
|
296 |
+
"execution_count": 11,
|
297 |
+
"metadata": {
|
298 |
+
"gather": {
|
299 |
+
"logged": 1706550381778
|
300 |
+
}
|
301 |
+
}
|
302 |
+
},
|
303 |
+
{
|
304 |
+
"cell_type": "code",
|
305 |
+
"source": [
|
306 |
+
"def compute_metrics(eval_pred):\n",
|
307 |
+
" predictions, labels = eval_pred\n",
|
308 |
+
" predictions = np.argmax(predictions, axis=1)\n",
|
309 |
+
" return {\n",
|
310 |
+
" 'accuracy': accuracy.compute(predictions=predictions, references=labels)[\"accuracy\"],\n",
|
311 |
+
" 'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')[\"precision\"],\n",
|
312 |
+
" 'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')[\"precision\"],\n",
|
313 |
+
" 'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')[\"recall\"],\n",
|
314 |
+
" 'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')[\"recall\"],\n",
|
315 |
+
" 'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')[\"f1\"]\n",
|
316 |
+
" }"
|
317 |
+
],
|
318 |
+
"outputs": [],
|
319 |
+
"execution_count": 12,
|
320 |
+
"metadata": {
|
321 |
+
"gather": {
|
322 |
+
"logged": 1706550381891
|
323 |
+
}
|
324 |
+
}
|
325 |
+
},
|
326 |
+
{
|
327 |
+
"cell_type": "markdown",
|
328 |
+
"source": [
|
329 |
+
"## Training"
|
330 |
+
],
|
331 |
+
"metadata": {}
|
332 |
+
},
|
333 |
+
{
|
334 |
+
"cell_type": "markdown",
|
335 |
+
"source": [
|
336 |
+
"We specify a label map – this has to be done manually, even if `Datasets` has a function for it, as `AutoModelForSequenceClassification` requires an object with a length :("
|
337 |
+
],
|
338 |
+
"metadata": {}
|
339 |
+
},
|
340 |
+
{
|
341 |
+
"cell_type": "code",
|
342 |
+
"source": [
|
343 |
+
"label_map = {i: label for i, label in enumerate(dataset[\"test\"].features[\"label\"].names)}"
|
344 |
+
],
|
345 |
+
"outputs": [],
|
346 |
+
"execution_count": 13,
|
347 |
+
"metadata": {
|
348 |
+
"gather": {
|
349 |
+
"logged": 1706550382032
|
350 |
+
}
|
351 |
+
}
|
352 |
+
},
|
353 |
+
{
|
354 |
+
"cell_type": "code",
|
355 |
+
"source": [
|
356 |
+
"def train_from_model(model_ckpt: str, push: bool = False):\n",
|
357 |
+
" print(f\"Initialising training based on {model_ckpt}...\")\n",
|
358 |
+
"\n",
|
359 |
+
" print(\"Tokenising...\")\n",
|
360 |
+
" tokenizer = AutoTokenizer.from_pretrained(model_ckpt)\n",
|
361 |
+
"\n",
|
362 |
+
" cols = dataset[\"train\"].column_names\n",
|
363 |
+
" cols.remove(\"label\")\n",
|
364 |
+
" ds_enc = dataset.map(lambda x: tokenizer(x[\"text\"], truncation=True, max_length=512), batched=True, remove_columns=cols)\n",
|
365 |
+
"\n",
|
366 |
+
" print(\"Loading model...\")\n",
|
367 |
+
" try:\n",
|
368 |
+
" model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
369 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
370 |
+
" id2label=label_map, \n",
|
371 |
+
" label2id={v:k for k,v in label_map.items()})\n",
|
372 |
+
" except OSError:\n",
|
373 |
+
" model = AutoModelForSequenceClassification.from_pretrained(model_ckpt, \n",
|
374 |
+
" num_labels=len(dataset[\"test\"].features[\"label\"].names), \n",
|
375 |
+
" id2label=label_map, \n",
|
376 |
+
" label2id={v:k for k,v in label_map.items()},\n",
|
377 |
+
" from_tf=True)\n",
|
378 |
+
"\n",
|
379 |
+
"\n",
|
380 |
+
" args = TrainingArguments(\n",
|
381 |
+
" output_dir=\"vaers\",\n",
|
382 |
+
" evaluation_strategy=\"epoch\",\n",
|
383 |
+
" save_strategy=\"epoch\",\n",
|
384 |
+
" learning_rate=2e-5,\n",
|
385 |
+
" per_device_train_batch_size=BATCH_SIZE,\n",
|
386 |
+
" per_device_eval_batch_size=BATCH_SIZE,\n",
|
387 |
+
" num_train_epochs=EPOCHS,\n",
|
388 |
+
" weight_decay=.01,\n",
|
389 |
+
" logging_steps=1,\n",
|
390 |
+
" load_best_model_at_end=True,\n",
|
391 |
+
" run_name=f\"daedra-training\",\n",
|
392 |
+
" report_to=[\"wandb\"])\n",
|
393 |
+
"\n",
|
394 |
+
" trainer = Trainer(\n",
|
395 |
+
" model=model,\n",
|
396 |
+
" args=args,\n",
|
397 |
+
" train_dataset=ds_enc[\"train\"],\n",
|
398 |
+
" eval_dataset=ds_enc[\"test\"],\n",
|
399 |
+
" tokenizer=tokenizer,\n",
|
400 |
+
" compute_metrics=compute_metrics)\n",
|
401 |
+
" \n",
|
402 |
+
" if SUBSAMPLING != 1.0:\n",
|
403 |
+
" wandb_tag: List[str] = [f\"subsample-{SUBSAMPLING}\"]\n",
|
404 |
+
" else:\n",
|
405 |
+
" wandb_tag: List[str] = [f\"full_sample\"]\n",
|
406 |
+
"\n",
|
407 |
+
" wandb_tag.append(f\"batch_size-{BATCH_SIZE}\")\n",
|
408 |
+
" wandb_tag.append(f\"base:{model_ckpt}\")\n",
|
409 |
+
" \n",
|
410 |
+
" wandb.init(name=f\"daedra_{SUBSAMPLING}-{model_ckpt}\", tags=wandb_tag, magic=True)\n",
|
411 |
+
"\n",
|
412 |
+
" print(\"Starting training...\")\n",
|
413 |
+
"\n",
|
414 |
+
" trainer.train()\n",
|
415 |
+
"\n",
|
416 |
+
" print(\"Training finished.\")\n",
|
417 |
+
"\n",
|
418 |
+
" if push:\n",
|
419 |
+
" variant = \"full_sample\" if SUBSAMPLING == 1.0 else f\"subsample-{SUBSAMPLING}\"\n",
|
420 |
+
" tokenizer._tokenizer.save(\"tokenizer.json\")\n",
|
421 |
+
" tokenizer.push_to_hub(\"chrisvoncsefalvay/daedra\")\n",
|
422 |
+
" sample = \"full sample\" if SUBSAMPLING == 1.0 else f\"{SUBSAMPLING * 100}% of the full sample\"\n",
|
423 |
+
"\n",
|
424 |
+
" model.push_to_hub(\"chrisvoncsefalvay/daedra\", \n",
|
425 |
+
" variant=variant,\n",
|
426 |
+
" commit_message=f\"DAEDRA model trained on {sample} of the VAERS dataset (training set size: {dataset['train'].num_rows:,}), based on {model_ckpt}\")"
|
427 |
+
],
|
428 |
+
"outputs": [],
|
429 |
+
"execution_count": 14,
|
430 |
+
"metadata": {
|
431 |
+
"jupyter": {
|
432 |
+
"outputs_hidden": false,
|
433 |
+
"source_hidden": false
|
434 |
+
},
|
435 |
+
"nteract": {
|
436 |
+
"transient": {
|
437 |
+
"deleting": false
|
438 |
+
}
|
439 |
+
},
|
440 |
+
"gather": {
|
441 |
+
"logged": 1706550382160
|
442 |
+
}
|
443 |
+
}
|
444 |
+
},
|
445 |
+
{
|
446 |
+
"cell_type": "code",
|
447 |
+
"source": [
|
448 |
+
"\n",
|
449 |
+
"base_models = [\n",
|
450 |
+
" \"bert-base-uncased\",\n",
|
451 |
+
" \"distilbert-base-uncased\",\n",
|
452 |
+
"]"
|
453 |
+
],
|
454 |
+
"outputs": [],
|
455 |
+
"execution_count": 15,
|
456 |
+
"metadata": {
|
457 |
+
"gather": {
|
458 |
+
"logged": 1706550382318
|
459 |
+
}
|
460 |
+
}
|
461 |
+
},
|
462 |
+
{
|
463 |
+
"cell_type": "code",
|
464 |
+
"source": [
|
465 |
+
"BATCH_SIZE=1\n",
|
466 |
+
"\n",
|
467 |
+
"train_from_model(\"biobert/Bio_ClinicalBERT/\")"
|
468 |
+
],
|
469 |
+
"outputs": [
|
470 |
+
{
|
471 |
+
"output_type": "stream",
|
472 |
+
"name": "stdout",
|
473 |
+
"text": "Initialising training based on biobert/Bio_ClinicalBERT/...\nTokenising...\nLoading model...\n"
|
474 |
+
},
|
475 |
+
{
|
476 |
+
"output_type": "stream",
|
477 |
+
"name": "stderr",
|
478 |
+
"text": "Map: 100%|██████████| 2722/2722 [00:01<00:00, 2195.12 examples/s]\nAll TF 2.0 model weights were used when initializing BertForSequenceClassification.\n\nAll the weights of BertForSequenceClassification were initialized from the TF 2.0 model.\nIf your task is similar to the task the model of the checkpoint was trained on, you can already use BertForSequenceClassification for predictions without further training.\n"
|
479 |
+
},
|
480 |
+
{
|
481 |
+
"output_type": "display_data",
|
482 |
+
"data": {
|
483 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
484 |
+
"text/html": "Finishing last run (ID:sg022tqh) before initializing another..."
|
485 |
+
},
|
486 |
+
"metadata": {}
|
487 |
+
},
|
488 |
+
{
|
489 |
+
"output_type": "display_data",
|
490 |
+
"data": {
|
491 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
492 |
+
"text/html": " View run <strong style=\"color:#cdcd00\">daedra_0.01-biobert/Bio_ClinicalBERT/</strong> at: <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/sg022tqh</a><br/> View job at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/jobs/QXJ0aWZhY3RDb2xsZWN0aW9uOjEzNDcyMTQwMw==/version_details/v6</a><br/>Synced 6 W&B file(s), 0 media file(s), 0 artifact file(s) and 0 other file(s)"
|
493 |
+
},
|
494 |
+
"metadata": {}
|
495 |
+
},
|
496 |
+
{
|
497 |
+
"output_type": "display_data",
|
498 |
+
"data": {
|
499 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
500 |
+
"text/html": "Find logs at: <code>./wandb/run-20240129_174816-sg022tqh/logs</code>"
|
501 |
+
},
|
502 |
+
"metadata": {}
|
503 |
+
},
|
504 |
+
{
|
505 |
+
"output_type": "display_data",
|
506 |
+
"data": {
|
507 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
508 |
+
"text/html": "Successfully finished last run (ID:sg022tqh). Initializing new run:<br/>"
|
509 |
+
},
|
510 |
+
"metadata": {}
|
511 |
+
},
|
512 |
+
{
|
513 |
+
"output_type": "display_data",
|
514 |
+
"data": {
|
515 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
516 |
+
"text/html": "Tracking run with wandb version 0.16.2"
|
517 |
+
},
|
518 |
+
"metadata": {}
|
519 |
+
},
|
520 |
+
{
|
521 |
+
"output_type": "display_data",
|
522 |
+
"data": {
|
523 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
524 |
+
"text/html": "Run data is saved locally in <code>/mnt/batch/tasks/shared/LS_root/mounts/clusters/daedra-hptrain-cvc/code/Users/kristof.csefalvay/daedra/notebooks/wandb/run-20240129_174936-kilkkg1j</code>"
|
525 |
+
},
|
526 |
+
"metadata": {}
|
527 |
+
},
|
528 |
+
{
|
529 |
+
"output_type": "display_data",
|
530 |
+
"data": {
|
531 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
532 |
+
"text/html": "Syncing run <strong><a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">daedra_0.01-biobert/Bio_ClinicalBERT/</a></strong> to <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">Weights & Biases</a> (<a href='https://wandb.me/run' target=\"_blank\">docs</a>)<br/>"
|
533 |
+
},
|
534 |
+
"metadata": {}
|
535 |
+
},
|
536 |
+
{
|
537 |
+
"output_type": "display_data",
|
538 |
+
"data": {
|
539 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
540 |
+
"text/html": " View project at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training</a>"
|
541 |
+
},
|
542 |
+
"metadata": {}
|
543 |
+
},
|
544 |
+
{
|
545 |
+
"output_type": "display_data",
|
546 |
+
"data": {
|
547 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
548 |
+
"text/html": " View run at <a href='https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j' target=\"_blank\">https://wandb.ai/chrisvoncsefalvay/DAEDRA%20multiclass%20model%20training/runs/kilkkg1j</a>"
|
549 |
+
},
|
550 |
+
"metadata": {}
|
551 |
+
},
|
552 |
+
{
|
553 |
+
"output_type": "stream",
|
554 |
+
"name": "stdout",
|
555 |
+
"text": "Starting training...\n"
|
556 |
+
},
|
557 |
+
{
|
558 |
+
"output_type": "stream",
|
559 |
+
"name": "stderr",
|
560 |
+
"text": "Was asked to gather along dimension 0, but all input tensors were scalars; will instead unsqueeze and return a vector.\n"
|
561 |
+
},
|
562 |
+
{
|
563 |
+
"output_type": "display_data",
|
564 |
+
"data": {
|
565 |
+
"text/plain": "<IPython.core.display.HTML object>",
|
566 |
+
"text/html": "\n <div>\n \n <progress value='1496' max='15880' style='width:300px; height:20px; vertical-align: middle;'></progress>\n [ 1496/15880 07:43 < 1:14:19, 3.23 it/s, Epoch 0.47/5]\n </div>\n <table border=\"1\" class=\"dataframe\">\n <thead>\n <tr style=\"text-align: left;\">\n <th>Epoch</th>\n <th>Training Loss</th>\n <th>Validation Loss</th>\n </tr>\n </thead>\n <tbody>\n </tbody>\n</table><p>"
|
567 |
+
},
|
568 |
+
"metadata": {}
|
569 |
+
}
|
570 |
+
],
|
571 |
+
"execution_count": 21,
|
572 |
+
"metadata": {
|
573 |
+
"gather": {
|
574 |
+
"logged": 1706551053473
|
575 |
+
}
|
576 |
+
}
|
577 |
+
},
|
578 |
+
{
|
579 |
+
"cell_type": "code",
|
580 |
+
"source": [],
|
581 |
+
"outputs": [],
|
582 |
+
"execution_count": null,
|
583 |
+
"metadata": {
|
584 |
+
"jupyter": {
|
585 |
+
"source_hidden": false,
|
586 |
+
"outputs_hidden": false
|
587 |
+
},
|
588 |
+
"nteract": {
|
589 |
+
"transient": {
|
590 |
+
"deleting": false
|
591 |
+
}
|
592 |
+
}
|
593 |
+
}
|
594 |
+
}
|
595 |
+
],
|
596 |
+
"metadata": {
|
597 |
+
"datalore": {
|
598 |
+
"base_environment": "default",
|
599 |
+
"computation_mode": "JUPYTER",
|
600 |
+
"package_manager": "pip",
|
601 |
+
"packages": [
|
602 |
+
{
|
603 |
+
"name": "datasets",
|
604 |
+
"source": "PIP",
|
605 |
+
"version": "2.16.1"
|
606 |
+
},
|
607 |
+
{
|
608 |
+
"name": "torch",
|
609 |
+
"source": "PIP",
|
610 |
+
"version": "2.1.2"
|
611 |
+
},
|
612 |
+
{
|
613 |
+
"name": "accelerate",
|
614 |
+
"source": "PIP",
|
615 |
+
"version": "0.26.1"
|
616 |
+
}
|
617 |
+
],
|
618 |
+
"report_row_ids": [
|
619 |
+
"un8W7ez7ZwoGb5Co6nydEV",
|
620 |
+
"40nN9Hvgi1clHNV5RAemI5",
|
621 |
+
"TgRD90H5NSPpKS41OeXI1w",
|
622 |
+
"ZOm5BfUs3h1EGLaUkBGeEB",
|
623 |
+
"kOP0CZWNSk6vqE3wkPp7Vc",
|
624 |
+
"W4PWcOu2O2pRaZyoE2W80h",
|
625 |
+
"RolbOnQLIftk0vy9mIcz5M",
|
626 |
+
"8OPhUgbaNJmOdiq5D3a6vK",
|
627 |
+
"5Qrt3jSvSrpK6Ne1hS6shL",
|
628 |
+
"hTq7nFUrovN5Ao4u6dIYWZ",
|
629 |
+
"I8WNZLpJ1DVP2wiCW7YBIB",
|
630 |
+
"SawhU3I9BewSE1XBPstpNJ",
|
631 |
+
"80EtLEl2FIE4FqbWnUD3nT"
|
632 |
+
],
|
633 |
+
"version": 3
|
634 |
+
},
|
635 |
+
"kernel_info": {
|
636 |
+
"name": "python38-azureml-pt-tf"
|
637 |
+
},
|
638 |
+
"kernelspec": {
|
639 |
+
"display_name": "azureml_py38_PT_TF",
|
640 |
+
"language": "python",
|
641 |
+
"name": "python3"
|
642 |
+
},
|
643 |
+
"language_info": {
|
644 |
+
"name": "python",
|
645 |
+
"version": "3.8.5",
|
646 |
+
"mimetype": "text/x-python",
|
647 |
+
"codemirror_mode": {
|
648 |
+
"name": "ipython",
|
649 |
+
"version": 3
|
650 |
+
},
|
651 |
+
"pygments_lexer": "ipython3",
|
652 |
+
"nbconvert_exporter": "python",
|
653 |
+
"file_extension": ".py"
|
654 |
+
},
|
655 |
+
"microsoft": {
|
656 |
+
"host": {
|
657 |
+
"AzureML": {
|
658 |
+
"notebookHasBeenCompleted": true
|
659 |
+
}
|
660 |
+
},
|
661 |
+
"ms_spell_check": {
|
662 |
+
"ms_spell_check_language": "en"
|
663 |
+
}
|
664 |
+
},
|
665 |
+
"nteract": {
|
666 |
+
"version": "nteract-front-end@1.0.0"
|
667 |
+
}
|
668 |
+
},
|
669 |
+
"nbformat": 4,
|
670 |
+
"nbformat_minor": 4
|
671 |
+
}
|
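The label-map remark in the notebook above deserves a concrete illustration: `AutoModelForSequenceClassification` expects explicit `id2label`/`label2id` dictionaries sized to the label set. A minimal sketch follows; the label names are hypothetical placeholders, not the actual VAERS outcome labels, and the checkpoint is the same `distilbert-base-uncased` default the notebook uses.

```python
# Sketch of the id2label / label2id wiring that train_from_model relies on.
# Label names below are hypothetical placeholders.
from transformers import AutoModelForSequenceClassification

labels = ["no_outcome", "er_visit", "hospitalised"]  # assumed example labels
id2label = {i: label for i, label in enumerate(labels)}
label2id = {label: i for i, label in id2label.items()}

model = AutoModelForSequenceClassification.from_pretrained(
    "distilbert-base-uncased",
    num_labels=len(labels),  # the classifier head is sized from this
    id2label=id2label,
    label2id=label2id,
)
print(model.config.id2label)  # {0: 'no_outcome', 1: 'er_visit', 2: 'hospitalised'}
```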
notebooks/daedra.py
ADDED
@@ -0,0 +1,134 @@
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from typing import List, Union
|
6 |
+
from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel
|
7 |
+
from datasets import load_dataset, Dataset, DatasetDict
|
8 |
+
import shap
|
9 |
+
import wandb
|
10 |
+
import evaluate
|
11 |
+
import logging
|
12 |
+
|
13 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
14 |
+
|
15 |
+
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
|
16 |
+
|
17 |
+
SEED: int = 42
|
18 |
+
|
19 |
+
BATCH_SIZE: int = 16
|
20 |
+
EPOCHS: int = 3
|
21 |
+
SUBSAMPLING: float = 0.1
|
22 |
+
|
23 |
+
# WandB configuration
|
24 |
+
os.environ["WANDB_PROJECT"] = "DAEDRA multiclass model training"
|
25 |
+
os.environ["WANDB_LOG_MODEL"] = "checkpoint" # log all model checkpoints
|
26 |
+
os.environ["WANDB_NOTEBOOK_NAME"] = "DAEDRA.ipynb"
|
27 |
+
|
28 |
+
dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes")
|
29 |
+
|
30 |
+
if SUBSAMPLING < 1:
|
31 |
+
_ = DatasetDict()
|
32 |
+
for each in dataset.keys():
|
33 |
+
_[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))
|
34 |
+
|
35 |
+
dataset = _
|
36 |
+
|
37 |
+
accuracy = evaluate.load("accuracy")
|
38 |
+
precision, recall = evaluate.load("precision"), evaluate.load("recall")
|
39 |
+
f1 = evaluate.load("f1")
|
40 |
+
|
41 |
+
def compute_metrics(eval_pred):
|
42 |
+
predictions, labels = eval_pred
|
43 |
+
predictions = np.argmax(predictions, axis=1)
|
44 |
+
return {
|
45 |
+
'accuracy': accuracy.compute(predictions=predictions, references=labels)["accuracy"],
|
46 |
+
'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')["precision"],
|
47 |
+
'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')["precision"],
|
48 |
+
'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')["recall"],
|
49 |
+
'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')["recall"],
|
50 |
+
'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
|
51 |
+
}
|
52 |
+
|
53 |
+
label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}
|
54 |
+
|
55 |
+
def train_from_model(model_ckpt: str, push: bool = False):
|
56 |
+
print(f"Initialising training based on {model_ckpt}...")
|
57 |
+
|
58 |
+
print("Tokenising...")
|
59 |
+
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
|
60 |
+
|
61 |
+
cols = dataset["train"].column_names
|
62 |
+
cols.remove("label")
|
63 |
+
ds_enc = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=512), batched=True, remove_columns=cols)
|
64 |
+
|
65 |
+
print("Loading model...")
|
66 |
+
try:
|
67 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
|
68 |
+
num_labels=len(dataset["test"].features["label"].names),
|
69 |
+
id2label=label_map,
|
70 |
+
label2id={v:k for k,v in label_map.items()})
|
71 |
+
except OSError:
|
72 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
|
73 |
+
num_labels=len(dataset["test"].features["label"].names),
|
74 |
+
id2label=label_map,
|
75 |
+
label2id={v:k for k,v in label_map.items()},
|
76 |
+
from_tf=True)
|
77 |
+
|
78 |
+
|
79 |
+
args = TrainingArguments(
|
80 |
+
output_dir="vaers",
|
81 |
+
evaluation_strategy="steps",
|
82 |
+
eval_steps=100,
|
83 |
+
save_strategy="epoch",
|
84 |
+
learning_rate=2e-5,
|
85 |
+
per_device_train_batch_size=BATCH_SIZE,
|
86 |
+
per_device_eval_batch_size=BATCH_SIZE,
|
87 |
+
num_train_epochs=EPOCHS,
|
88 |
+
weight_decay=.01,
|
89 |
+
logging_steps=1,
|
90 |
+
run_name=f"daedra-minisample-comparison-{SUBSAMPLING}",
|
91 |
+
report_to=["wandb"])
|
92 |
+
|
93 |
+
trainer = Trainer(
|
94 |
+
model=model,
|
95 |
+
args=args,
|
96 |
+
train_dataset=ds_enc["train"],
|
97 |
+
eval_dataset=ds_enc["test"],
|
98 |
+
tokenizer=tokenizer,
|
99 |
+
compute_metrics=compute_metrics)
|
100 |
+
|
101 |
+
if SUBSAMPLING != 1.0:
|
102 |
+
wandb_tag: List[str] = [f"subsample-{SUBSAMPLING}"]
|
103 |
+
else:
|
104 |
+
wandb_tag: List[str] = ["full_sample"]
|
105 |
+
|
106 |
+
wandb_tag.append(f"batch_size-{BATCH_SIZE}")
|
107 |
+
wandb_tag.append(f"base:{model_ckpt}")
|
108 |
+
|
109 |
+
if "/" in model_ckpt:
|
110 |
+
sanitised_model_name = model_ckpt.split("/")[1]
|
111 |
+
else:
|
112 |
+
sanitised_model_name = model_ckpt
|
113 |
+
|
114 |
+
wandb.init(name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}", tags=wandb_tag, magic=True)
|
115 |
+
|
116 |
+
print("Starting training...")
|
117 |
+
|
118 |
+
trainer.train()
|
119 |
+
|
120 |
+
print("Training finished.")
|
121 |
+
|
122 |
+
wandb.finish()
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
wandb.finish()
|
126 |
+
|
127 |
+
for mname in (
|
128 |
+
#"dmis-lab/biobert-base-cased-v1.2",
|
129 |
+
"emilyalsentzer/Bio_ClinicalBERT",
|
130 |
+
"bert-base-uncased",
|
131 |
+
"distilbert-base-uncased"
|
132 |
+
):
|
133 |
+
print(f"Now training on subsample with {mname}...")
|
134 |
+
train_from_model(mname)
|
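Before launching a long run, the `compute_metrics` logic in `daedra.py` can be smoke-tested on synthetic logits. The sketch below trims it to two metrics to stay short; the logits and labels are made up, and only calls already present in the script are used.

```python
# Smoke test for the compute_metrics logic in notebooks/daedra.py,
# reduced to accuracy and micro-averaged F1. Inputs are synthetic.
import numpy as np
import evaluate

accuracy = evaluate.load("accuracy")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)  # logits -> predicted class ids
    return {
        "accuracy": accuracy.compute(predictions=predictions, references=labels)["accuracy"],
        "f1_microaverage": f1.compute(predictions=predictions, references=labels, average="micro")["f1"],
    }

# Four examples, three classes; the last row is deliberately misclassified.
logits = np.array([
    [2.0, 0.1, 0.3],
    [0.2, 1.5, 0.1],
    [0.1, 0.2, 3.0],
    [0.9, 1.0, 0.8],
])
labels = np.array([0, 1, 2, 0])
print(compute_metrics((logits, labels)))  # accuracy and micro-F1 both 0.75
```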
notebooks/daedra.py.amltmp
ADDED
@@ -0,0 +1,134 @@
|
1 |
+
import pandas as pd
|
2 |
+
import numpy as np
|
3 |
+
import torch
|
4 |
+
import os
|
5 |
+
from typing import List, Union
|
6 |
+
from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel
|
7 |
+
from datasets import load_dataset, Dataset, DatasetDict
|
8 |
+
import shap
|
9 |
+
import wandb
|
10 |
+
import evaluate
|
11 |
+
import logging
|
12 |
+
|
13 |
+
os.environ["TOKENIZERS_PARALLELISM"] = "false"
|
14 |
+
|
15 |
+
device: str = 'cuda' if torch.cuda.is_available() else 'cpu'
|
16 |
+
|
17 |
+
SEED: int = 42
|
18 |
+
|
19 |
+
BATCH_SIZE: int = 16
|
20 |
+
EPOCHS: int = 3
|
21 |
+
SUBSAMPLING: float = 0.1
|
22 |
+
|
23 |
+
# WandB configuration
|
24 |
+
os.environ["WANDB_PROJECT"] = "DAEDRA multiclass model training"
|
25 |
+
os.environ["WANDB_LOG_MODEL"] = "checkpoint" # log all model checkpoints
|
26 |
+
os.environ["WANDB_NOTEBOOK_NAME"] = "DAEDRA.ipynb"
|
27 |
+
|
28 |
+
dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes")
|
29 |
+
|
30 |
+
if SUBSAMPLING < 1:
|
31 |
+
_ = DatasetDict()
|
32 |
+
for each in dataset.keys():
|
33 |
+
_[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))
|
34 |
+
|
35 |
+
dataset = _
|
36 |
+
|
37 |
+
accuracy = evaluate.load("accuracy")
|
38 |
+
precision, recall = evaluate.load("precision"), evaluate.load("recall")
|
39 |
+
f1 = evaluate.load("f1")
|
40 |
+
|
41 |
+
def compute_metrics(eval_pred):
|
42 |
+
predictions, labels = eval_pred
|
43 |
+
predictions = np.argmax(predictions, axis=1)
|
44 |
+
return {
|
45 |
+
'accuracy': accuracy.compute(predictions=predictions, references=labels)["accuracy"],
|
46 |
+
'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')["precision"],
|
47 |
+
'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')["precision"],
|
48 |
+
'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')["recall"],
|
49 |
+
'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')["recall"],
|
50 |
+
'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
|
51 |
+
}
|
52 |
+
|
53 |
+
label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}
|
54 |
+
|
55 |
+
def train_from_model(model_ckpt: str, push: bool = False):
|
56 |
+
print(f"Initialising training based on {model_ckpt}...")
|
57 |
+
|
58 |
+
print("Tokenising...")
|
59 |
+
tokenizer = AutoTokenizer.from_pretrained(model_ckpt)
|
60 |
+
|
61 |
+
cols = dataset["train"].column_names
|
62 |
+
cols.remove("label")
|
63 |
+
ds_enc = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=512), batched=True, remove_columns=cols)
|
64 |
+
|
65 |
+
print("Loading model...")
|
66 |
+
try:
|
67 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
|
68 |
+
num_labels=len(dataset["test"].features["label"].names),
|
69 |
+
id2label=label_map,
|
70 |
+
label2id={v:k for k,v in label_map.items()})
|
71 |
+
except OSError:
|
72 |
+
model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
|
73 |
+
num_labels=len(dataset["test"].features["label"].names),
|
74 |
+
id2label=label_map,
|
75 |
+
label2id={v:k for k,v in label_map.items()},
|
76 |
+
from_tf=True)
|
77 |
+
|
78 |
+
|
79 |
+
args = TrainingArguments(
|
80 |
+
output_dir="vaers",
|
81 |
+
evaluation_strategy="steps",
|
82 |
+
eval_steps=100,
|
83 |
+
save_strategy="epoch",
|
84 |
+
learning_rate=2e-5,
|
85 |
+
per_device_train_batch_size=BATCH_SIZE,
|
86 |
+
per_device_eval_batch_size=BATCH_SIZE,
|
87 |
+
num_train_epochs=EPOCHS,
|
88 |
+
weight_decay=.01,
|
89 |
+
logging_steps=1,
|
90 |
+
run_name=f"daedra-minisample-comparison-{SUBSAMPLING}",
|
91 |
+
report_to=["wandb"])
|
92 |
+
|
93 |
+
trainer = Trainer(
|
94 |
+
model=model,
|
95 |
+
args=args,
|
96 |
+
train_dataset=ds_enc["train"],
|
97 |
+
eval_dataset=ds_enc["test"],
|
98 |
+
tokenizer=tokenizer,
|
99 |
+
compute_metrics=compute_metrics)
|
100 |
+
|
101 |
+
if SUBSAMPLING != 1.0:
|
102 |
+
wandb_tag: List[str] = [f"subsample-{SUBSAMPLING}"]
|
103 |
+
else:
|
104 |
+
wandb_tag: List[str] = [f"full_sample"]
|
105 |
+
|
106 |
+
wandb_tag.append(f"batch_size-{BATCH_SIZE}")
|
107 |
+
wandb_tag.append(f"base:{model_ckpt}")
|
108 |
+
|
109 |
+
if "/" in model_ckpt:
|
110 |
+
sanitised_model_name = model_ckpt.split("/")[1]
|
111 |
+
else:
|
112 |
+
sanitised_model_name = model_ckpt
|
113 |
+
|
114 |
+
wandb.init(name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}", tags=wandb_tag, magic=True)
|
115 |
+
|
116 |
+
print("Starting training...")
|
117 |
+
|
118 |
+
trainer.train()
|
119 |
+
|
120 |
+
print("Training finished.")
|
121 |
+
|
122 |
+
wandb.finish()
|
123 |
+
|
124 |
+
if __name__ == "__main__":
|
125 |
+
wandb.finish()
|
126 |
+
|
127 |
+
for mname in (
|
128 |
+
#"dmis-lab/biobert-base-cased-v1.2",
|
129 |
+
"emilyalsentzer/Bio_ClinicalBERT",
|
130 |
+
"bert-base-uncased",
|
131 |
+
"distilbert-base-uncased"
|
132 |
+
):
|
133 |
+
print(f"Now training on subsample with {mname}...")
|
134 |
+
train_from_model(mname)
|
notebooks/daedra_final_training.py.amltmp
ADDED
@@ -0,0 +1,136 @@
import pandas as pd
import numpy as np
import torch
import os
from typing import List, Union
from transformers import AutoTokenizer, Trainer, AutoModelForSequenceClassification, TrainingArguments, DataCollatorWithPadding, pipeline, AutoModel
from datasets import load_dataset, Dataset, DatasetDict
import shap
import wandb
import evaluate
import logging
from codecarbon import EmissionsTracker


tracker = EmissionsTracker()

os.environ["TOKENIZERS_PARALLELISM"] = "false"

device: str = 'cuda' if torch.cuda.is_available() else 'cpu'

SEED: int = 42

BATCH_SIZE: int = 16
EPOCHS: int = 3
SUBSAMPLING: float = 1.0

# WandB configuration
os.environ["WANDB_PROJECT"] = "DAEDRA final model training"
os.environ["WANDB_LOG_MODEL"] = "checkpoint"  # log all model checkpoints

dataset = load_dataset("chrisvoncsefalvay/vaers-outcomes")

if SUBSAMPLING < 1:
    _ = DatasetDict()
    for each in dataset.keys():
        _[each] = dataset[each].shuffle(seed=SEED).select(range(int(len(dataset[each]) * SUBSAMPLING)))

    dataset = _

accuracy = evaluate.load("accuracy")
precision, recall = evaluate.load("precision"), evaluate.load("recall")
f1 = evaluate.load("f1")

def compute_metrics(eval_pred):
    predictions, labels = eval_pred
    predictions = np.argmax(predictions, axis=1)
    return {
        'accuracy': accuracy.compute(predictions=predictions, references=labels)["accuracy"],
        'precision_macroaverage': precision.compute(predictions=predictions, references=labels, average='macro')["precision"],
        'precision_microaverage': precision.compute(predictions=predictions, references=labels, average='micro')["precision"],
        'recall_macroaverage': recall.compute(predictions=predictions, references=labels, average='macro')["recall"],
        'recall_microaverage': recall.compute(predictions=predictions, references=labels, average='micro')["recall"],
        'f1_microaverage': f1.compute(predictions=predictions, references=labels, average='micro')["f1"]
    }

label_map = {i: label for i, label in enumerate(dataset["test"].features["label"].names)}

def train_from_model(model_ckpt: str, push: bool = False):  # note: the push argument is currently unused
    print(f"Initialising training based on {model_ckpt}...")

    print("Tokenising...")
    tokenizer = AutoTokenizer.from_pretrained(model_ckpt)

    cols = dataset["train"].column_names
    cols.remove("label")
    ds_enc = dataset.map(lambda x: tokenizer(x["text"], truncation=True, max_length=512), batched=True, remove_columns=cols)

    print("Loading model...")
    try:
        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
                                                                   num_labels=len(dataset["test"].features["label"].names),
                                                                   id2label=label_map,
                                                                   label2id={v: k for k, v in label_map.items()})
    except OSError:
        # Fall back to converting TensorFlow weights if no PyTorch weights exist.
        model = AutoModelForSequenceClassification.from_pretrained(model_ckpt,
                                                                   num_labels=len(dataset["test"].features["label"].names),
                                                                   id2label=label_map,
                                                                   label2id={v: k for k, v in label_map.items()},
                                                                   from_tf=True)

    args = TrainingArguments(
        output_dir="daedra",
        evaluation_strategy="steps",
        eval_steps=1000,
        save_steps=2000,
        save_strategy="steps",
        learning_rate=2e-5,
        per_device_train_batch_size=BATCH_SIZE,
        per_device_eval_batch_size=BATCH_SIZE,
        num_train_epochs=EPOCHS,
        weight_decay=0.01,
        logging_steps=1,
        run_name="daedra-full-train",
        report_to=["wandb", "codecarbon"],
        save_total_limit=2,
        load_best_model_at_end=True,
        push_to_hub=True,
        push_to_hub_model_id="daedra",
        hub_strategy="every_save",
        metric_for_best_model="f1_microaverage")

    trainer = Trainer(
        model=model,
        args=args,
        train_dataset=ds_enc["train"],
        eval_dataset=ds_enc["test"],
        tokenizer=tokenizer,
        compute_metrics=compute_metrics)

    wandb_tag: List[str] = ["full_sample"]

    wandb_tag.append(f"batch_size-{BATCH_SIZE}")
    wandb_tag.append(f"base:{model_ckpt}")

    if "/" in model_ckpt:
        sanitised_model_name = model_ckpt.split("/")[1]
    else:
        sanitised_model_name = model_ckpt

    wandb.init(name=f"daedra_{SUBSAMPLING}-{sanitised_model_name}", tags=wandb_tag, magic=True)

    print("Starting training...")

    tracker.start()
    trainer.train()
    tracker.stop()

    print("Training finished.")

    wandb.finish()

if __name__ == "__main__":
    wandb.finish()

    train_from_model("dmis-lab/biobert-base-cased-v1.2")
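Because the final run pushes every save to the Hub (push_to_hub=True, hub_strategy="every_save"), the trained model can be pulled back for inference. A hedged sketch; the repo id chrisvoncsefalvay/daedra is an assumption inferred from push_to_hub_model_id="daedra" and the dataset namespace:

from transformers import pipeline

# Assumed repo id; adjust to wherever the Trainer actually pushed the model.
clf = pipeline("text-classification", model="chrisvoncsefalvay/daedra")
print(clf("Patient developed a mild rash at the injection site."))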
notebooks/emissions.csv
ADDED
@@ -0,0 +1,3 @@
timestamp,project_name,run_id,duration,emissions,emissions_rate,cpu_power,gpu_power,ram_power,cpu_energy,gpu_energy,ram_energy,energy_consumed,country_name,country_iso_code,region,cloud_provider,cloud_region,os,python_version,codecarbon_version,cpu_count,cpu_model,gpu_count,gpu_model,longitude,latitude,ram_total_size,tracking_mode,on_cloud,pue
2024-01-29T03:05:13,codecarbon,6bfec408-4fcc-427a-8e94-0cabc9332665,10637.685039520264,0.9110964852888171,8.564800348045516e-05,42.5,148.01654697980965,165.33123922348022,0.1255709600533049,1.8546672673437372,0.4879591008899785,2.468197328287024,United States,USA,virginia,,,Linux-5.15.0-1040-azure-x86_64-with-glibc2.10,3.8.5,2.3.3,24,Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz,4,4 x Tesla V100-PCIE-16GB,-76.8545,37.9273,2f,35031.49867606163,3.0209420758007166,8.623502247892816e-05,42.5,148.9286737800754,165.33123922348022,0.413520891640749,6.163406995166088,1.6069267112465608,8.183854598053408,United States,USA,virginia,,,Linux-5.15.0-1040-azure-x86_64-with-glibc2.10,3.8.5,2.3.3,24,Intel(R) Xeon(R) CPU E5-2690 v4 @ 2.60GHz,4,4 x Tesla V100-PCIE-16GB,-76.8545,37.9273,440440.88330459594727,machine,N,1.0
2024-01-29T14:29:57,codecarbon,484aeab5-8bdc-4fbc-8f66-0c204b0f2a.88330459594727,machine,N,1.0
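A quick sketch for inspecting the CodeCarbon log with pandas; the column names follow the header row above, and on_bad_lines="skip" guards against the truncated rows in this file:

import pandas as pd

# Load the emissions log; skip any rows that do not parse cleanly.
emissions = pd.read_csv("notebooks/emissions.csv", on_bad_lines="skip")
print(emissions[["timestamp", "duration", "emissions", "energy_consumed"]])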
notebooks/emissions.csv.amltmp
ADDED
@@ -0,0 +1,3 @@
(Azure ML temporary autosave copy; contents identical to notebooks/emissions.csv above.)
notebooks/microsample_model_comparison.ipynb
ADDED
File without changes
notebooks/tokenizer.json
ADDED
The diff for this file is too large to render.
notebooks/wandb/.amlignore
ADDED
@@ -0,0 +1,6 @@
## This file was auto generated by the Azure Machine Learning Studio. Please do not remove.
## Read more about the .amlignore file here: https://docs.microsoft.com/azure/machine-learning/how-to-save-write-experiment-files#storage-limits-of-experiment-snapshots

.ipynb_aml_checkpoints/
*.amltmp
*.amltemp
notebooks/wandb/.amlignore.amltmp
ADDED
@@ -0,0 +1,6 @@
(Azure ML temporary autosave copy; contents identical to notebooks/wandb/.amlignore above.)
paper/.gitkeep
ADDED
File without changes
tokenizer.json
CHANGED
@@ -1,6 +1,11 @@
 {
   "version": "1.0",
-  "truncation": null,
+  "truncation": {
+    "direction": "Right",
+    "max_length": 512,
+    "strategy": "LongestFirst",
+    "stride": 0
+  },
   "padding": null,
   "added_tokens": [
     {
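The truncation block saved here mirrors the max_length=512 tokenisation used in the training scripts; a sketch of the equivalent call, with an illustrative checkpoint name:

from transformers import AutoTokenizer

# truncation=True applies the longest_first strategy, truncating from the right,
# which is exactly what the saved truncation block encodes.
tok = AutoTokenizer.from_pretrained("bert-base-uncased")
enc = tok("an example VAERS report narrative", truncation=True, max_length=512)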
training_args.bin
CHANGED
@@ -1,3 +1,3 @@
 version https://git-lfs.github.com/spec/v1
-oid sha256:
+oid sha256:c847edf58c0470c1a32d6d0f580f3c732e43c689025195de8e292f71fbb85be6
 size 4728