Hugging Face Space (status: Runtime error)

Commit 4dd3f33 by rajaatif786: Duplicate from rajaatif786/VirBert
Files changed:
- .gitattributes +34 -0
- BERT/main/berttok/vocab.txt +622 -0
- README.md +13 -0
- Toxonomy/modules/__pycache__/classifier.cpython-39.pyc +0 -0
- Toxonomy/modules/__pycache__/confusionmatrix.cpython-39.pyc +0 -0
- Toxonomy/modules/__pycache__/preprocessor.cpython-39.pyc +0 -0
- Toxonomy/modules/classifier.py +526 -0
- Toxonomy/modules/confusionmatrix.py +36 -0
- Toxonomy/modules/preprocessor.py +44 -0
- app.py +82 -0
- requirements.txt +6 -0
- virBERT.pt +3 -0
.gitattributes
ADDED
@@ -0,0 +1,34 @@
*.7z filter=lfs diff=lfs merge=lfs -text
*.arrow filter=lfs diff=lfs merge=lfs -text
*.bin filter=lfs diff=lfs merge=lfs -text
*.bz2 filter=lfs diff=lfs merge=lfs -text
*.ckpt filter=lfs diff=lfs merge=lfs -text
*.ftz filter=lfs diff=lfs merge=lfs -text
*.gz filter=lfs diff=lfs merge=lfs -text
*.h5 filter=lfs diff=lfs merge=lfs -text
*.joblib filter=lfs diff=lfs merge=lfs -text
*.lfs.* filter=lfs diff=lfs merge=lfs -text
*.mlmodel filter=lfs diff=lfs merge=lfs -text
*.model filter=lfs diff=lfs merge=lfs -text
*.msgpack filter=lfs diff=lfs merge=lfs -text
*.npy filter=lfs diff=lfs merge=lfs -text
*.npz filter=lfs diff=lfs merge=lfs -text
*.onnx filter=lfs diff=lfs merge=lfs -text
*.ot filter=lfs diff=lfs merge=lfs -text
*.parquet filter=lfs diff=lfs merge=lfs -text
*.pb filter=lfs diff=lfs merge=lfs -text
*.pickle filter=lfs diff=lfs merge=lfs -text
*.pkl filter=lfs diff=lfs merge=lfs -text
*.pt filter=lfs diff=lfs merge=lfs -text
*.pth filter=lfs diff=lfs merge=lfs -text
*.rar filter=lfs diff=lfs merge=lfs -text
*.safetensors filter=lfs diff=lfs merge=lfs -text
saved_model/**/* filter=lfs diff=lfs merge=lfs -text
*.tar.* filter=lfs diff=lfs merge=lfs -text
*.tflite filter=lfs diff=lfs merge=lfs -text
*.tgz filter=lfs diff=lfs merge=lfs -text
*.wasm filter=lfs diff=lfs merge=lfs -text
*.xz filter=lfs diff=lfs merge=lfs -text
*.zip filter=lfs diff=lfs merge=lfs -text
*.zst filter=lfs diff=lfs merge=lfs -text
*tfevents* filter=lfs diff=lfs merge=lfs -text
BERT/main/berttok/vocab.txt
ADDED
@@ -0,0 +1,622 @@
[CLS]
[SEP]
[mask]
a
b
c
d
g
h
k
m
n
r
s
t
v
w
y
##s
##g
##n
##a
##t
##c
##k
##y
##w
##m
##r
##v
##b
##h
##d
##aa
##ca
##tg
tg
gg
##ag
ag
##ct
##at
##cc
##ac
##tt
##tc
##ta
##cg
cg
caa
aaa
tgg
gga
aga
aca
atg
ggg
gaa
cca
aag
ctg
cag
tca
ttg
aat
agg
cat
gag
tga
aac
tgc
gca
tgt
cct
ccc
gtg
acc
ggc
cac
ctc
ctt
ttt
gct
act
gac
tct
gcc
att
tcc
gat
ttc
agc
atc
ata
ggt
gtt
agt
tac
tat
gtc
cta
taa
tta
cgg
gta
ccg
tag
gcg
cgc
acg
cgt
tcg
cga
nn
nnn
##nn
ca
ar
##gt
aa
##gg
##yt
nng
ann
ga
##ya
##ty
ac
can
tc
ngt
gr
tr
##tr
cr
##ga
cc
gc
##yc
##yg
##gc
ara
ta
##mt
gar
##tn
aar
##na
cnn
##ma
car
ayt
##nt
rat
ct
rgg
ggr
ytt
nna
rga
raa
ytg
rag
aay
crg
##wc
gty
rac
##kg
agr
arg
art
tya
tcr
##ng
rtt
naa
yat
cyt
yac
cay
rca
tty
gt
ttr
tgy
ayc
aya
aty
cya
rtc
yct
gra
tra
##mc
trt
##wt
##mg
##kc
##wg
at
nnt
nnc
acy
as
ctr
tcy
rta
tnn
yag
grt
gnn
yaa
yta
acn
gyt
##nc
##kt
aan
acr
tyt
yca
grc
ntc
ccy
gcn
gya
ggy
gay
trg
ytc
arc
rtg
rct
yga
gcy
gtr
crt
cty
cyg
ngc
tyg
ycc
grg
gs
cra
ccr
tay
cnt
ccn
ygg
ggn
atr
mtt
rcc
rgt
tyc
##wa
ayg
amt
nag
rgc
nca
sgg
cyc
nac
ygc
##sg
acm
ana
cma
ntg
ty
ygt
tgn
tgr
trc
ngg
gyg
cmt
maa
tcm
gcr
ttn
twc
agn
gyc
ncc
##ka
agy
mtc
nat
tt
crc
atn
kgg
ntt
ysg
tcn
tys
mgg
ts
tan
tmt
aam
gan
ctm
ang
mtg
nga
tar
ctn
cna
cgy
wct
wca
twt
ctk
ctw
gtw
gna
mat
nta
ggk
acw
gcm
cmg
kct
tna
ccm
awg
cwg
nct
tma
cas
tam
cmc
gkg
ant
wcc
gsa
gtn
wgc
cs
gng
ktt
mag
wtg
cgr
cak
gam
gtm
kgt
tkg
tkt
aak
ama
anc
kca
mta
sca
ggw
ccw
atm
asc
akg
amc
ckg
cwt
ckc
mca
mcc
tkc
tgs
aaw
gaw
tcw
saa
cam
atk
atw
asa
cwc
gmt
gwc
ktg
rcg
##an
tck
gtk
gnt
gnc
kga
nan
sag
scc
tng
wtt
wgt
ggs
ggm
ack
cck
akc
mct
mac
stg
tnt
waa
wtc
awt
amg
cnc
ckt
cwa
gma
kcc
sta
wta
ycg
kag
mgt
ncg
tmg
caw
tas
akt
cng
gwa
mcg
tmc
wgg
aas
gcw
asg
tsc
awc
gkc
ktc
kgc
tnc
wac
wga
tgk
agm
gas
cts
ast
ttw
ttm
tst
gkt
twa
wag
tgw
tgm
cgs
kaa
kta
mga
##ar
gak
gst
aka
cka
gmg
rr
sct
sac
gcs
gck
ats
ay
gmc
gka
kac
mgc
ng
nc
sat
stt
twg
tka
agk
agw
gsc
ngn
ncn
gwg
sga
cgm
csa
ntn
sr
yg
tcs
gts
gsg
cst
cy
cw
kcg
stc
cgn
cgk
acs
ttk
tsa
cm
rar
sgt
wcg
yyt
cgw
tts
csc
awa
csg
gwt
rt
tm
wr
wat
yt
tak
tyy
ak
gw
ma
rc
sgc
tw
tsg
wg
wa
wc
ya
yma
##st
##sr
##vt
ags
trr
ccs
taw
rra
rrt
ayk
kat
mr
ry
ryt
yc
ywc
ymg
ykt
##ks
cry
tym
ayw
aym
ygk
cyy
rts
tmm
wrg
wgr
am
ad
gy
kg
kya
ksr
mt
mc
mty
mma
mar
rs
ra
ss
sg
scg
syg
tb
tvt
tks
vtg
wkg
wmc
wst
yy
yr
ytr
ywa
##wk
##ms
grr
ctb
asr
gsy
sra
src
ygm
cyw
cwm
cwr
cmw
cmr
tmy
ytw
ytm
aky
gww
mam
rcw
twm
wak
wcr
yay
mrr
adc
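The vocabulary above contains 622 WordPiece entries over lower-case DNA k-mers (including IUPAC ambiguity codes such as n, r, y), plus the special tokens [CLS], [SEP] and [mask]. A minimal sketch of how this tokenizer splits a 3-mer "sentence", assuming the repository is checked out locally so the path below exists:

from transformers import BertTokenizer

# Load the WordPiece tokenizer built from the vocab file above.
tokenizer = BertTokenizer.from_pretrained("./BERT/main/berttok")

# Inputs are space-separated 3-mers (see kmers_sentences in app.py below),
# so each word usually maps directly onto a vocabulary entry.
sentence = "atg tgg gta"
print(tokenizer.tokenize(sentence))              # ['atg', 'tgg', 'gta']
print(tokenizer.convert_tokens_to_ids(['atg']))  # the id of 'atg' in vocab.txt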
README.md
ADDED
@@ -0,0 +1,13 @@
---
title: VirBert
emoji: 👀
colorFrom: yellow
colorTo: purple
sdk: gradio
sdk_version: 3.23.0
app_file: app.py
pinned: false
duplicated_from: rajaatif786/VirBert
---

Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
Toxonomy/modules/__pycache__/classifier.cpython-39.pyc
ADDED
Binary file (8.96 kB)

Toxonomy/modules/__pycache__/confusionmatrix.cpython-39.pyc
ADDED
Binary file (1.41 kB)

Toxonomy/modules/__pycache__/preprocessor.cpython-39.pyc
ADDED
Binary file (1.14 kB)
Toxonomy/modules/classifier.py
ADDED
@@ -0,0 +1,526 @@
# BERT classifiers for viral DNA sequences: a small encoder trained from
# scratch (PretrainedBert, ScratchBert) and a fine-tuning wrapper that
# reuses the pretrained encoder (FinetunningBert).
import random
import time

import numpy as np
import torch
import torch.nn as nn
from transformers import AdamW, get_linear_schedule_with_warmup
from transformers import BertConfig, BertModel

# Run on the GPU when available; fall back to CPU (e.g. on a CPU-only Space).
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

# Specify loss function
loss_fn = nn.CrossEntropyLoss()


class PretrainedBert(nn.Module):
    """BERT model for classification tasks (pre-training stage, 14 labels)."""

    def __init__(self, freeze_bert=False):
        """
        @param freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        super(PretrainedBert, self).__init__()
        # Hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 50, 14

        # Instantiate a small BERT encoder from scratch
        config = BertConfig(
            max_position_embeddings=5000,  # longest input the position embeddings can cover
            hidden_size=768,
            num_attention_heads=2,
            num_hidden_layers=2,
            type_vocab_size=1
        )
        self.bert = BertModel(config)

        # Instantiate a one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits.
        @param input_ids (torch.Tensor): an input tensor with shape (batch_size, max_length)
        @param attention_mask (torch.Tensor): a tensor that holds attention-mask
            information, with shape (batch_size, max_length)
        @return logits (torch.Tensor): an output tensor with shape (batch_size, num_labels)
        """
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Extract the last hidden state of the `[CLS]` token for the classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed it to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits


def valid_evaluate(model, val_dataloader):
    """After the completion of each training epoch, measure the model's
    performance on our validation set.
    """
    # Put the model into evaluation mode; dropout layers are disabled at test time.
    model.eval()

    # Tracking variables
    val_accuracy = []
    val_loss = []

    # For each batch in our validation set...
    for batch in val_dataloader:
        # Load batch to the device
        b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

        # Compute logits
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)

        # Compute loss
        loss = loss_fn(logits, b_labels)
        val_loss.append(loss.item())

        # Get the predictions
        preds = torch.argmax(logits, dim=1).flatten()

        # Calculate the accuracy rate
        accuracy = (preds == b_labels).cpu().numpy().mean() * 100
        val_accuracy.append(accuracy)

    # Compute the average accuracy and loss over the validation set.
    val_loss = np.mean(val_loss)
    val_accuracy = np.mean(val_accuracy)

    return val_loss, val_accuracy


class FinetunningBert(nn.Module):
    """BERT model for classification tasks (fine-tuning stage, 2 labels)."""

    def __init__(self, virus_dir, freeze_bert=False):
        """
        @param virus_dir (str): directory containing the pretrained virBERT.pt checkpoint
        @param freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        super(FinetunningBert, self).__init__()
        # Hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 50, 2

        # Load the pretrained classifier and reuse its BERT encoder
        bert_classifier = PretrainedBert(freeze_bert=False)
        bert_classifier.load_state_dict(
            torch.load(virus_dir + '/virBERT.pt', map_location=device))
        self.bert = bert_classifier.bert.to(device)

        # Instantiate a one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits
        (same contract as PretrainedBert.forward).
        """
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Extract the last hidden state of the `[CLS]` token for the classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed it to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits


def initialize_finetunningBert(train_dataloader, virus_dir, epochs=4):
    """Initialize the Bert Classifier, the optimizer and the learning rate scheduler."""
    # Instantiate Bert Classifier
    bert_classifier = FinetunningBert(virus_dir, freeze_bert=False)

    # Tell PyTorch to run the model on the selected device
    bert_classifier.to(device)

    # Create the optimizer
    optimizer = AdamW(bert_classifier.parameters(),
                      lr=5e-5,  # Default learning rate
                      eps=1e-8  # Default epsilon value
                      )

    # Total number of training steps
    total_steps = len(train_dataloader) * epochs

    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,  # Default value
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler


def finetunningBert_training(model, optimizer, scheduler, train_dataloader,
                             val_dataloader=None, epochs=4, evaluation=False):
    """Train the BertClassifier model."""
    # Start training loop
    print("Start training...\n")
    for epoch_i in range(epochs):
        # =======================================
        #               Training
        # =======================================
        # Print the header of the result table
        print(f"{'Epoch':^7} | {'Batch':^7} | {'Train Loss':^12} | {'Val Loss':^10} | {'Val Acc':^9} | {'Elapsed':^9}")
        print("-" * 70)

        # Measure the elapsed time of each epoch
        t0_epoch, t0_batch = time.time(), time.time()

        # Reset tracking variables at the beginning of each epoch
        total_loss, batch_loss, batch_counts = 0, 0, 0

        # Put the model into training mode
        model.train()

        # For each batch of training data...
        for step, batch in enumerate(train_dataloader):
            batch_counts += 1
            # Load batch to the device
            b_input_ids, b_attn_mask, b_labels = tuple(t.to(device) for t in batch)

            # Zero out any previously calculated gradients
            model.zero_grad()

            # Perform a forward pass. This will return logits.
            logits = model(b_input_ids, b_attn_mask)

            # Compute loss and accumulate the loss values
            loss = loss_fn(logits, b_labels)
            batch_loss += loss.item()
            total_loss += loss.item()

            # Perform a backward pass to calculate gradients
            loss.backward()

            # Clip the norm of the gradients to 1.0 to prevent "exploding gradients"
            torch.nn.utils.clip_grad_norm_(model.parameters(), 1.0)

            # Update parameters and the learning rate
            optimizer.step()
            scheduler.step()

            # Print the loss values and time elapsed for every 20 batches
            if (step % 20 == 0 and step != 0) or (step == len(train_dataloader) - 1):
                # Calculate time elapsed for 20 batches
                time_elapsed = time.time() - t0_batch

                # Print training results
                print(f"{epoch_i + 1:^7} | {step:^7} | {batch_loss / batch_counts:^12.6f} | {'-':^10} | {'-':^9} | {time_elapsed:^9.2f}")

                # Reset batch tracking variables
                batch_loss, batch_counts = 0, 0
                t0_batch = time.time()

        # Calculate the average loss over the entire training data
        avg_train_loss = total_loss / len(train_dataloader)
        torch.save(model.state_dict(), '{}model.pt'.format("VirDNA"))
        print("-" * 70)
        # =======================================
        #               Evaluation
        # =======================================
        if evaluation:
            # After the completion of each training epoch, measure the model's
            # performance on our validation set.
            val_loss, val_accuracy = valid_evaluate(model, val_dataloader)

            # Print performance over the entire training data
            time_elapsed = time.time() - t0_epoch

            print(f"{epoch_i + 1:^7} | {'-':^7} | {avg_train_loss:^12.6f} | {val_loss:^10.6f} | {val_accuracy:^9.2f} | {time_elapsed:^9.2f}")
            print("-" * 70)
        print("\n")

    print("Training complete!")


def bertPredictions(torch, model, val_dataloader):
    """Run the model over a dataloader and collect the predicted label indices.

    Note: `torch` is passed in explicitly by the caller (see app.py), and the
    batches are kept on the CPU.
    """
    # Put the model into evaluation mode; dropout layers are disabled at test time.
    model.eval()

    # Tracking variable
    pred = []

    # For each batch in the dataloader...
    for batch in val_dataloader:
        # Unpack the batch (kept on CPU; the third slot is a label placeholder)
        b_input_ids, b_attn_mask, b_labels = tuple(t for t in batch)

        # Compute logits
        with torch.no_grad():
            logits = model(b_input_ids, b_attn_mask)

        # Get the predictions
        preds = torch.argmax(logits, dim=1).flatten()
        pred.append(preds.cpu())

    return pred


class ScratchBert(nn.Module):
    """BERT model for classification tasks, trained from scratch (2 labels)."""

    def __init__(self, freeze_bert=False):
        """
        @param freeze_bert (bool): Set `False` to fine-tune the BERT model
        """
        super(ScratchBert, self).__init__()
        # Hidden size of BERT, hidden size of our classifier, and number of labels
        D_in, H, D_out = 768, 50, 2

        # Instantiate a small BERT encoder from scratch
        config = BertConfig(
            max_position_embeddings=5000,  # longest input the position embeddings can cover
            hidden_size=768,
            num_attention_heads=2,
            num_hidden_layers=2,
            type_vocab_size=1
        )
        self.bert = BertModel(config)

        # Instantiate a one-layer feed-forward classifier
        self.classifier = nn.Sequential(
            nn.Linear(D_in, H),
            nn.ReLU(),
            # nn.Dropout(0.5),
            nn.Linear(H, D_out)
        )

        # Freeze the BERT model
        if freeze_bert:
            for param in self.bert.parameters():
                param.requires_grad = False

    def forward(self, input_ids, attention_mask):
        """
        Feed input to BERT and the classifier to compute logits
        (same contract as PretrainedBert.forward).
        """
        # Feed input to BERT
        outputs = self.bert(input_ids=input_ids,
                            attention_mask=attention_mask)

        # Extract the last hidden state of the `[CLS]` token for the classification task
        last_hidden_state_cls = outputs[0][:, 0, :]

        # Feed it to the classifier to compute logits
        logits = self.classifier(last_hidden_state_cls)

        return logits


def initialize_model(train_dataloader, epochs=4):
    """Initialize the Bert Classifier, the optimizer and the learning rate scheduler."""
    # Instantiate Bert Classifier
    bert_classifier = ScratchBert(freeze_bert=False)

    # Tell PyTorch to run the model on the selected device
    bert_classifier.to(device)

    # Create the optimizer
    optimizer = AdamW(bert_classifier.parameters(),
                      lr=5e-5,  # Default learning rate
                      eps=1e-8  # Default epsilon value
                      )

    # Total number of training steps
    total_steps = len(train_dataloader) * epochs

    # Set up the learning rate scheduler
    scheduler = get_linear_schedule_with_warmup(optimizer,
                                                num_warmup_steps=0,  # Default value
                                                num_training_steps=total_steps)
    return bert_classifier, optimizer, scheduler


def train(model, optimizer, scheduler, train_dataloader, val_dataloader=None,
          epochs=4, evaluation=False):
    """Train the BertClassifier model.

    The training loop is identical to finetunningBert_training, so delegate to it.
    """
    finetunningBert_training(model, optimizer, scheduler, train_dataloader,
                             val_dataloader=val_dataloader, epochs=epochs,
                             evaluation=evaluation)
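As a quick smoke test of the model above, the following sketch runs a randomly initialized PretrainedBert on dummy inputs; the batch shape and token ids are illustrative assumptions, not values from the repository:

import torch
from Toxonomy.modules.classifier import PretrainedBert

model = PretrainedBert(freeze_bert=True)
model.eval()

# Two dummy sequences of 16 token ids each (arbitrary ids below the vocab size).
input_ids = torch.randint(0, 622, (2, 16))
attention_mask = torch.ones(2, 16, dtype=torch.long)

with torch.no_grad():
    logits = model(input_ids, attention_mask)
print(logits.shape)  # torch.Size([2, 14]): one score per label (D_out = 14)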
Toxonomy/modules/confusionmatrix.py
ADDED
@@ -0,0 +1,36 @@
import itertools

import matplotlib.pyplot as plt
import numpy as np


def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Greens):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    # Annotate each cell, switching text color against the background
    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()
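For illustration, a short usage sketch of the helper above; the matrix values and class names below are made up:

import numpy as np
import matplotlib.pyplot as plt
from Toxonomy.modules.confusionmatrix import plot_confusion_matrix

cm = np.array([[8, 2],
               [1, 9]])
plot_confusion_matrix(cm, classes=["class 0", "class 1"], normalize=True)
plt.show()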
Toxonomy/modules/preprocessor.py
ADDED
@@ -0,0 +1,44 @@
from transformers import BertTokenizer
import torch

# Load the BERT tokenizer trained on DNA 3-mers (see BERT/main/berttok/vocab.txt)
tokenizer = BertTokenizer.from_pretrained('./BERT/main/berttok')


# Create a function to tokenize a set of texts
def preprocessing_for_bert(data):
    """Perform required preprocessing steps for pretrained BERT.
    @param data (np.array): Array of texts to be processed.
    @return input_ids (torch.Tensor): Tensor of token ids to be fed to a model.
    @return attention_masks (torch.Tensor): Tensor of indices specifying which
        tokens should be attended to by the model.
    """
    # Create empty lists to store outputs
    input_ids = []
    attention_masks = []

    # For every sentence...
    for sent in data:
        # `encode_plus` will:
        #    (1) Tokenize the sentence
        #    (2) Add the `[CLS]` and `[SEP]` token to the start and end
        #    (3) Truncate/pad the sentence to max length
        #    (4) Map tokens to their IDs
        #    (5) Create the attention mask
        #    (6) Return a dictionary of outputs
        encoded_sent = tokenizer.encode_plus(
            text=sent,                   # Preprocess sentence
            add_special_tokens=True,     # Add `[CLS]` and `[SEP]`
            max_length=5000,             # Max length to truncate/pad
            padding='max_length',        # Pad sentence to max length
            truncation=True,             # Truncate longer sequences
            # return_tensors='pt',       # Return PyTorch tensor
            return_attention_mask=True   # Return attention mask
        )

        # Add the outputs to the lists
        input_ids.append(encoded_sent.get('input_ids'))
        attention_masks.append(encoded_sent.get('attention_mask'))

    # Convert lists to tensors
    input_ids = torch.tensor(input_ids)
    attention_masks = torch.tensor(attention_masks)

    return input_ids, attention_masks
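A short usage sketch: every input is padded out to 5000 token positions, so even a brief k-mer sentence yields a (1, 5000) tensor. The sentence below is an arbitrary example, and this assumes the tokenizer's special and padding tokens resolve against the small vocabulary above:

from Toxonomy.modules.preprocessor import preprocessing_for_bert

sentences = ["atg tgg gta acc"]  # one space-separated 3-mer sentence
input_ids, attention_masks = preprocessing_for_bert(sentences)
print(input_ids.shape)           # torch.Size([1, 5000])
print(attention_masks[0].sum())  # count of real (non-padding) positions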
app.py
ADDED
@@ -0,0 +1,82 @@
import gradio as gr

# from transformers import pipeline
#
# pipe = pipeline("translation", model="Helsinki-NLP/opus-mt-en-es")
#
# def predict(text):
#     return pipe(text)[0]["translation_text"]
#
# iface = gr.Interface(
#     fn=predict,
#     inputs='text',
#     outputs='text',
#     examples=[["Hello! My name is Omar"]]
# )
#
# iface.launch()

import pandas as pd
import torch
from torch.utils.data import TensorDataset, DataLoader, SequentialSampler

from Toxonomy.modules.preprocessor import preprocessing_for_bert
from Toxonomy.modules.classifier import PretrainedBert, bertPredictions


def Kmers_funct(seq, size):
    """Split a sequence into overlapping, lower-cased k-mers of the given size."""
    return [seq[x:x + size].lower() for x in range(len(seq) - size + 1)]


def kmers_sentences(mySeq):
    """Turn a raw DNA sequence into a space-separated 'sentence' of 3-mers."""
    # Kmers_funct(mySeq, size=7)
    words = Kmers_funct(mySeq, size=3)
    joined_sentence = ' '.join(words)
    return joined_sentence


def predict(text):
    print(text)
    temp_df = pd.DataFrame([text]).astype('str')
    temp_df.columns = ['seq']
    # Keep only sequences of at most 7000 bases
    mask = temp_df['seq'].str.len() <= 7000
    temp_df = temp_df.loc[mask]
    temp_df['Processed'] = temp_df['seq'].apply(kmers_sentences)
    test_inputs, test_masks = preprocessing_for_bert(temp_df['Processed'])
    # There are no labels at inference time, so the masks are passed
    # again as a placeholder third tensor.
    test_data = TensorDataset(test_inputs, test_masks, test_masks)
    test_sampler = SequentialSampler(test_data)
    test_dataloader = DataLoader(test_data, sampler=test_sampler, batch_size=8)
    bert_classifier = PretrainedBert(freeze_bert=False)
    # Load the checkpoint on the CPU (the Space has no GPU)
    bert_classifier.load_state_dict(
        torch.load("./virBERT.pt", map_location=torch.device('cpu')))
    print(next(bert_classifier.parameters()).is_cuda)  # confirm the model is on CPU
    pred = bertPredictions(torch, bert_classifier, test_dataloader)
    return str(pred)


iface = gr.Interface(
    fn=predict,
    inputs='text',
    outputs='text'
)

iface.launch()
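The k-mer step is easiest to see on a concrete input: Kmers_funct slides a window of size 3 over the sequence, so a 6-base input yields 4 overlapping 3-mers. A standalone sketch with an arbitrary sequence (to_kmer_sentence is a hypothetical name that mirrors kmers_sentences above):

# Mirrors what kmers_sentences in app.py does to a raw sequence.
def to_kmer_sentence(seq, size=3):
    return ' '.join(seq[x:x + size].lower() for x in range(len(seq) - size + 1))

print(to_kmer_sentence("ATGGTA"))  # "atg tgg ggt gta"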
requirements.txt
ADDED
@@ -0,0 +1,6 @@
transformers
tensorflow
sentencepiece
tokenizers
torch
scikit-learn
pandas      # used by app.py
matplotlib  # used by Toxonomy/modules/confusionmatrix.py
virBERT.pt
ADDED
@@ -0,0 +1,3 @@
version https://git-lfs.github.com/spec/v1
oid sha256:91ba86b1a74084de062f7b9d70af958d8c781e2958b582cc96574177cd9b3b68
size 168411425