Joblib
ynuozhang committed on
Commit baf3373 · 1 Parent(s): b8c6018

update code

This view is limited to 50 files because it contains too many changes. See raw diff
Files changed (50)
  1. metrics/nonfouling/train_predictions_binary.csv +0 -0
  2. metrics/nonfouling/val_predictions_binary.csv +3 -3438
  3. tokenizer/.ipynb_checkpoints/my_tokenizers-checkpoint.py +398 -0
  4. tokenizer/__pycache__/my_tokenizers.cpython-310.pyc +0 -0
  5. tokenizer/my_tokenizers.py +398 -0
  6. tokenizer/new_splits.txt +159 -0
  7. tokenizer/new_vocab.txt +586 -0
  8. training_classifiers/.gitignore +0 -0
  9. training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py +132 -0
  10. training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py +847 -0
  11. training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py +414 -0
  12. training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash +31 -0
  13. training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py +508 -0
  14. training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py +309 -0
  15. training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt +234 -0
  16. training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py +417 -0
  17. training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py +468 -0
  18. training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py +410 -0
  19. training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py +426 -0
  20. training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py +420 -0
  21. training_data_cleaned/data_split.ipynb → training_classifiers/binding_affinity/val_smiles_pooled.csv +2 -2
  22. training_data_cleaned/nf_smiles_train.csv → training_classifiers/binding_affinity/val_smiles_unpooled.csv +2 -2
  23. training_data_cleaned/smiles_data_split.ipynb → training_classifiers/binding_affinity/val_wt_pooled.csv +2 -2
  24. training_data_cleaned/nf_smiles_val.csv → training_classifiers/binding_affinity/val_wt_unpooled.csv +2 -2
  25. training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt +3 -0
  26. training_classifiers/binding_affinity/wt_smiles_pooled/best_params.json +10 -0
  27. training_classifiers/binding_affinity/wt_smiles_pooled/optuna_trials.csv +3 -0
  28. training_classifiers/binding_affinity/wt_smiles_unpooled/.ipynb_checkpoints/best_params-checkpoint.json +10 -0
  29. training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt +3 -0
  30. training_classifiers/binding_affinity/wt_smiles_unpooled/best_params.json +10 -0
  31. training_classifiers/binding_affinity/wt_smiles_unpooled/optuna_trials.csv +3 -0
  32. training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv +3 -0
  33. training_classifiers/binding_affinity/wt_wt_pooled/best_model.pt +3 -0
  34. training_classifiers/binding_affinity/wt_wt_pooled/best_params.json +10 -0
  35. training_classifiers/binding_affinity/wt_wt_pooled/optuna_trials.csv +3 -0
  36. training_classifiers/binding_affinity/wt_wt_unpooled/best_model.pt +3 -0
  37. training_classifiers/binding_affinity/wt_wt_unpooled/best_params.json +10 -0
  38. training_classifiers/binding_affinity/wt_wt_unpooled/optuna_trials.csv +3 -0
  39. training_classifiers/binding_affinity_iptm.py +132 -0
  40. training_classifiers/binding_affinity_split.py +847 -0
  41. training_classifiers/binding_training.py +414 -0
  42. training_classifiers/binding_wt.bash +31 -0
  43. training_classifiers/hemolysis/cnn_smiles/best_model.pt +3 -0
  44. training_classifiers/hemolysis/cnn_smiles/best_model_benchmark.json +39 -0
  45. training_classifiers/hemolysis/cnn_smiles/optimization_summary.txt +19 -0
  46. training_classifiers/hemolysis/cnn_smiles/pr_curve.png +0 -0
  47. training_classifiers/hemolysis/cnn_smiles/roc_curve.png +0 -0
  48. training_classifiers/hemolysis/cnn_smiles/study_trials.csv +3 -0
  49. training_classifiers/hemolysis/cnn_smiles/train_predictions.csv +3 -0
  50. training_classifiers/hemolysis/cnn_smiles/val_predictions.csv +3 -0
metrics/nonfouling/train_predictions_binary.csv CHANGED
The diff for this file is too large to render. See raw diff
 
metrics/nonfouling/val_predictions_binary.csv CHANGED
@@ -1,3438 +1,3 @@
1
- True Label,Predicted Probability,Predicted Label
2
- 0,0.21203287,0
3
- 1,0.7625684,1
4
- 0,0.0048168427,0
5
- 0,0.0016095717,0
6
- 0,0.0010283176,0
7
- 1,0.6661874,1
8
- 0,0.25750402,0
9
- 0,0.4311336,0
10
- 0,0.00085044367,0
11
- 0,0.0011397039,0
12
- 0,0.00088075985,0
13
- 0,0.0020098046,0
14
- 0,0.0010203379,0
15
- 0,0.0008713204,0
16
- 0,0.0017780334,0
17
- 0,0.00342609,0
18
- 0,0.00080001954,0
19
- 0,0.0027687384,0
20
- 0,0.5390543,1
21
- 0,0.8238043,1
22
- 0,0.0012005109,0
23
- 0,0.00097379234,0
24
- 0,0.10598443,0
25
- 0,0.25473288,0
26
- 0,0.0013989281,0
27
- 1,0.8493596,1
28
- 0,0.3361668,0
29
- 0,0.0011108345,0
30
- 0,0.00085180654,0
31
- 0,0.57711107,1
32
- 0,0.6323541,1
33
- 0,0.00092630496,0
34
- 0,0.000880342,0
35
- 1,0.7630482,1
36
- 0,0.14976597,0
37
- 1,0.6572999,1
38
- 0,0.0013483333,0
39
- 0,0.0010882191,0
40
- 0,0.0009723114,0
41
- 0,0.00078559143,0
42
- 0,0.33249167,0
43
- 0,0.0010722216,0
44
- 0,0.00078953034,0
45
- 1,0.6428388,1
46
- 1,0.24572088,0
47
- 1,0.6876133,1
48
- 0,0.20859648,0
49
- 0,0.0020129737,0
50
- 0,0.56601,1
51
- 0,0.0012045301,0
52
- 0,0.0010913766,0
53
- 0,0.00096886675,0
54
- 1,0.59327847,1
55
- 0,0.24701594,0
56
- 0,0.00094282435,0
57
- 0,0.00080143235,0
58
- 0,0.00093203504,0
59
- 1,0.38815123,0
60
- 0,0.1851648,0
61
- 0,0.0012724681,0
62
- 0,0.5877677,1
63
- 0,0.00086790195,0
64
- 0,0.00084711233,0
65
- 0,0.00089334225,0
66
- 0,0.07253498,0
67
- 0,0.003544662,0
68
- 0,0.06225388,0
69
- 0,0.5347445,1
70
- 0,0.0015253695,0
71
- 0,0.32455897,0
72
- 0,0.0011612711,0
73
- 0,0.7157994,1
74
- 0,0.0011453595,0
75
- 0,0.001003294,0
76
- 0,0.0008531371,0
77
- 1,0.69131196,1
78
- 0,0.00073970045,0
79
- 0,0.00097454654,0
80
- 0,0.0009003313,0
81
- 1,0.6489763,1
82
- 0,0.0010982461,0
83
- 0,0.00080810016,0
84
- 0,0.0012549972,0
85
- 1,0.44953474,0
86
- 0,0.20876294,0
87
- 0,0.73931056,1
88
- 0,0.0009876546,0
89
- 0,0.29462275,0
90
- 0,0.39917734,0
91
- 0,0.000782265,0
92
- 0,0.42951488,0
93
- 0,0.001463046,0
94
- 1,0.68715686,1
95
- 1,0.42291453,0
96
- 0,0.0022332973,0
97
- 0,0.0009480977,0
98
- 0,0.0024130554,0
99
- 0,0.06350319,0
100
- 1,0.66712207,1
101
- 0,0.00079292513,0
102
- 0,0.0009931386,0
103
- 0,0.6005864,1
104
- 0,0.0008592792,0
105
- 0,0.12034535,0
106
- 1,0.6524521,1
107
- 0,0.0009471925,0
108
- 0,0.04952685,0
109
- 0,0.081261575,0
110
- 0,0.00092108996,0
111
- 0,0.43716368,0
112
- 0,0.0010192527,0
113
- 0,0.0009138947,0
114
- 0,0.00087209453,0
115
- 0,0.0007799222,0
116
- 1,0.71242714,1
117
- 0,0.0009697011,0
118
- 0,0.25283346,0
119
- 0,0.0033405088,0
120
- 0,0.0009034709,0
121
- 0,0.0027930748,0
122
- 0,0.0011101163,0
123
- 1,0.34234285,0
124
- 0,0.0009976418,0
125
- 0,0.0009782554,0
126
- 0,0.0010062164,0
127
- 0,0.0012275511,0
128
- 0,0.0007695558,0
129
- 0,0.0019192833,0
130
- 0,0.15324375,0
131
- 0,0.0008446999,0
132
- 0,0.0010133649,0
133
- 0,0.00087288895,0
134
- 0,0.1176671,0
135
- 0,0.0011782635,0
136
- 0,0.4397572,0
137
- 0,0.0007951967,0
138
- 0,0.0012110751,0
139
- 0,0.00088101206,0
140
- 0,0.12597291,0
141
- 1,0.5714519,1
142
- 1,0.50501525,1
143
- 0,0.0012025698,0
144
- 0,0.5724967,1
145
- 0,0.09112893,0
146
- 0,0.0013447981,0
147
- 0,0.0011533678,0
148
- 1,0.57831836,1
149
- 0,0.48726425,0
150
- 0,0.0011903379,0
151
- 0,0.00084324833,0
152
- 0,0.00082829024,0
153
- 0,0.085955195,0
154
- 0,0.0008221822,0
155
- 1,0.62516063,1
156
- 0,0.0010661274,0
157
- 0,0.47287834,0
158
- 0,0.30710956,0
159
- 0,0.01076754,0
160
- 1,0.21308176,0
161
- 1,0.7633791,1
162
- 0,0.45939833,0
163
- 0,0.0010105146,0
164
- 1,0.7428216,1
165
- 0,0.15860216,0
166
- 1,0.2822227,0
167
- 0,0.0011596903,0
168
- 0,0.00090244703,0
169
- 0,0.0011741482,0
170
- 1,0.53843015,1
171
- 0,0.0031817604,0
172
- 0,0.0009357714,0
173
- 0,0.00084562885,0
174
- 0,0.40793115,0
175
- 0,0.0009336455,0
176
- 0,0.0012610556,0
177
- 0,0.0009685405,0
178
- 0,0.0008348695,0
179
- 0,0.00084012165,0
180
- 0,0.0011492906,0
181
- 0,0.0009675606,0
182
- 0,0.6298985,1
183
- 0,0.756409,1
184
- 0,0.0012567933,0
185
- 0,0.67586565,1
186
- 0,0.00087434775,0
187
- 0,0.49520925,0
188
- 0,0.00083109684,0
189
- 0,0.0015268176,0
190
- 0,0.0009127699,0
191
- 0,0.47751015,0
192
- 0,0.0009006195,0
193
- 0,0.0015889019,0
194
- 0,0.27180874,0
195
- 1,0.6914706,1
196
- 0,0.789135,1
197
- 0,0.0009752872,0
198
- 0,0.000897456,0
199
- 0,0.0011121263,0
200
- 0,0.0015918579,0
201
- 0,0.6213229,1
202
- 0,0.6309986,1
203
- 0,0.0020146987,0
204
- 0,0.4184653,0
205
- 0,0.0020128626,0
206
- 0,0.21320935,0
207
- 0,0.0008362684,0
208
- 0,0.001030758,0
209
- 0,0.0011157958,0
210
- 0,0.0009790816,0
211
- 1,0.7094025,1
212
- 1,0.9093414,1
213
- 0,0.17732719,0
214
- 0,0.0009374656,0
215
- 0,0.00094623194,0
216
- 1,0.8687336,1
217
- 0,0.65683585,1
218
- 0,0.001121003,0
219
- 0,0.590965,1
220
- 0,0.6421117,1
221
- 0,0.0010331942,0
222
- 0,0.0009423399,0
223
- 0,0.26386958,0
224
- 0,0.0009064185,0
225
- 0,0.016375644,0
226
- 0,0.001191659,0
227
- 0,0.17972796,0
228
- 0,0.8418873,1
229
- 0,0.0009111323,0
230
- 0,0.26404643,0
231
- 0,0.0010872352,0
232
- 0,0.0017061317,0
233
- 0,0.81021,1
234
- 0,0.00087346835,0
235
- 0,0.14280483,0
236
- 0,0.16507958,0
237
- 0,0.0008098925,0
238
- 0,0.00083044183,0
239
- 0,0.0012201187,0
240
- 0,0.0010541711,0
241
- 0,0.28025955,0
242
- 0,0.0012255623,0
243
- 0,0.0010174535,0
244
- 0,0.0016367439,0
245
- 0,0.73789763,1
246
- 0,0.32007608,0
247
- 0,0.12180342,0
248
- 0,0.0015864924,0
249
- 0,0.70650285,1
250
- 0,0.00086796284,0
251
- 0,0.0009819808,0
252
- 0,0.099008076,0
253
- 0,0.19955282,0
254
- 1,0.46594706,0
255
- 0,0.00096476724,0
256
- 1,0.7014958,1
257
- 0,0.6206068,1
258
- 1,0.40477422,0
259
- 0,0.0010020116,0
260
- 0,0.55101025,1
261
- 0,0.0010169167,0
262
- 0,0.0011705819,0
263
- 0,0.43195137,0
264
- 0,0.33231077,0
265
- 0,0.0010730977,0
266
- 1,0.6236388,1
267
- 1,0.8430124,1
268
- 0,0.0010129493,0
269
- 1,0.55188674,1
270
- 0,0.2805052,0
271
- 0,0.49679318,0
272
- 0,0.00085314247,0
273
- 0,0.0009600679,0
274
- 0,0.000983881,0
275
- 0,0.0010313862,0
276
- 0,0.0010591929,0
277
- 0,0.0014554731,0
278
- 0,0.2721196,0
279
- 0,0.0010521681,0
280
- 0,0.0008385733,0
281
- 0,0.0010760311,0
282
- 0,0.56204385,1
283
- 0,0.0009945527,0
284
- 0,0.6538143,1
285
- 0,0.000972335,0
286
- 0,0.0015788077,0
287
- 0,0.00078132324,0
288
- 0,0.0009190443,0
289
- 0,0.0009206846,0
290
- 0,0.14582296,0
291
- 0,0.0009473762,0
292
- 0,0.00079042354,0
293
- 0,0.0010401446,0
294
- 0,0.3425446,0
295
- 1,0.7118903,1
296
- 0,0.0011929816,0
297
- 0,0.8595175,1
298
- 0,0.5998442,1
299
- 0,0.0008737177,0
300
- 1,0.80021775,1
301
- 0,0.0067423894,0
302
- 0,0.37208188,0
303
- 0,0.0009025443,0
304
- 1,0.4007288,0
305
- 0,0.00084359787,0
306
- 0,0.010245014,0
307
- 0,0.0009981133,0
308
- 0,0.0007901667,0
309
- 1,0.8616254,1
310
- 1,0.37549037,0
311
- 0,0.0010908762,0
312
- 1,0.82627475,1
313
- 0,0.7775751,1
314
- 1,0.64397454,1
315
- 0,0.000756293,0
316
- 0,0.000830868,0
317
- 0,0.0008689119,0
318
- 0,0.21817225,0
319
- 1,0.54163814,1
320
- 0,0.001031394,0
321
- 0,0.4650486,0
322
- 1,0.42690405,0
323
- 0,0.000841264,0
324
- 0,0.003041604,0
325
- 0,0.00094686233,0
326
- 0,0.0120981,0
327
- 1,0.12877595,0
328
- 0,0.0010682141,0
329
- 0,0.26241425,0
330
- 0,0.00917543,0
331
- 0,0.4467414,0
332
- 0,0.00091395644,0
333
- 0,0.0012513435,0
334
- 0,0.6208037,1
335
- 0,0.09325837,0
336
- 0,0.0014404579,0
337
- 0,0.0013578972,0
338
- 0,0.564982,1
339
- 0,0.0011169427,0
340
- 0,0.0013792963,0
341
- 0,0.0019788202,0
342
- 0,0.11213151,0
343
- 0,0.00093545037,0
344
- 0,0.00083710946,0
345
- 0,0.0009821211,0
346
- 0,0.33052954,0
347
- 1,0.63790786,1
348
- 0,0.38441333,0
349
- 0,0.65978384,1
350
- 0,0.16266404,0
351
- 0,0.0009782265,0
352
- 0,0.7020993,1
353
- 1,0.5939269,1
354
- 0,0.00086739147,0
355
- 1,0.80537134,1
356
- 1,0.5931398,1
357
- 0,0.0010550644,0
358
- 0,0.000939566,0
359
- 0,0.001143589,0
360
- 0,0.0015645259,0
361
- 0,0.0010082681,0
362
- 0,0.0011559983,0
363
- 0,0.0030087254,0
364
- 0,0.0019511192,0
365
- 0,0.4628644,0
366
- 0,0.46333745,0
367
- 0,0.0009037835,0
368
- 0,0.0012149046,0
369
- 0,0.0011464095,0
370
- 1,0.53254116,1
371
- 0,0.0011695515,0
372
- 0,0.0011102897,0
373
- 0,0.0015082303,0
374
- 0,0.0009924652,0
375
- 0,0.29110697,0
376
- 0,0.0009876668,0
377
- 0,0.00079951587,0
378
- 0,0.0010723416,0
379
- 0,0.0009410649,0
380
- 0,0.0044702576,0
381
- 0,0.0016339025,0
382
- 1,0.58247274,1
383
- 0,0.00095149525,0
384
- 0,0.0010258186,0
385
- 0,0.00090839405,0
386
- 1,0.5408768,1
387
- 0,0.00091459276,0
388
- 0,0.00096972886,0
389
- 0,0.000929176,0
390
- 0,0.00096466026,0
391
- 0,0.0012642274,0
392
- 0,0.0013000804,0
393
- 1,0.64949554,1
394
- 1,0.824855,1
395
- 0,0.5840037,1
396
- 0,0.0022679865,0
397
- 1,0.70146537,1
398
- 0,0.0008208898,0
399
- 1,0.5012139,1
400
- 0,0.0010077616,0
401
- 0,0.0011317601,0
402
- 1,0.5587163,1
403
- 0,0.0008192366,0
404
- 0,0.00091622194,0
405
- 1,0.26070353,0
406
- 1,0.7721127,1
407
- 0,0.0011609249,0
408
- 1,0.5806655,1
409
- 0,0.001057503,0
410
- 0,0.0009096564,0
411
- 0,0.63499725,1
412
- 1,0.77146506,1
413
- 0,0.00092794717,0
414
- 0,0.0011426172,0
415
- 1,0.6665107,1
416
- 0,0.44717062,0
417
- 1,0.7875412,1
418
- 0,0.6128087,1
419
- 0,0.0018639723,0
420
- 0,0.0011927163,0
421
- 0,0.0011212432,0
422
- 0,0.0010541681,0
423
- 1,0.759266,1
424
- 0,0.000915701,0
425
- 1,0.8248112,1
426
- 1,0.2618734,0
427
- 1,0.5829796,1
428
- 0,0.000971727,0
429
- 0,0.34199846,0
430
- 0,0.22960144,0
431
- 0,0.0008905708,0
432
- 0,0.7192157,1
433
- 0,0.5267322,1
434
- 0,0.0011434992,0
435
- 1,0.82825303,1
436
- 0,0.0007324419,0
437
- 0,0.0009348062,0
438
- 0,0.0018712084,0
439
- 1,0.7346453,1
440
- 0,0.0008732254,0
441
- 0,0.0010398608,0
442
- 1,0.78774214,1
443
- 0,0.0010178161,0
444
- 0,0.40890852,0
445
- 0,0.0007731539,0
446
- 0,0.30410865,0
447
- 1,0.6904336,1
448
- 0,0.0016686685,0
449
- 0,0.17082378,0
450
- 0,0.0019347976,0
451
- 0,0.00089052453,0
452
- 1,0.6660989,1
453
- 0,0.0010408974,0
454
- 0,0.27290353,0
455
- 0,0.009841075,0
456
- 0,0.0012475859,0
457
- 1,0.5256839,1
458
- 1,0.22151299,0
459
- 0,0.00091817125,0
460
- 0,0.5700492,1
461
- 0,0.19963185,0
462
- 0,0.0009827572,0
463
- 0,0.0008537978,0
464
- 0,0.00092485244,0
465
- 0,0.0012399766,0
466
- 0,0.0013511787,0
467
- 0,0.10416922,0
468
- 0,0.0008518869,0
469
- 0,0.000871039,0
470
- 0,0.001035327,0
471
- 0,0.000883175,0
472
- 0,0.0025706466,0
473
- 0,0.29247305,0
474
- 0,0.0008869903,0
475
- 1,0.71239984,1
476
- 1,0.7177157,1
477
- 0,0.0007620353,0
478
- 0,0.0625917,0
479
- 0,0.0009716233,0
480
- 0,0.0010923475,0
481
- 0,0.0009638545,0
482
- 0,0.0014103759,0
483
- 0,0.586873,1
484
- 0,0.1582566,0
485
- 0,0.7623745,1
486
- 0,0.00090825645,0
487
- 0,0.008724699,0
488
- 0,0.0008719578,0
489
- 0,0.0010188158,0
490
- 0,0.0008288342,0
491
- 0,0.00085399643,0
492
- 0,0.0011273371,0
493
- 0,0.62889636,1
494
- 1,0.20183899,0
495
- 0,0.0009553314,0
496
- 0,0.005830987,0
497
- 0,0.10113329,0
498
- 0,0.058883034,0
499
- 0,0.004936521,0
500
- 0,0.001236107,0
501
- 0,0.0008304486,0
502
- 0,0.0012260479,0
503
- 0,0.4102268,0
504
- 1,0.63618875,1
505
- 1,0.33070007,0
506
- 0,0.7466114,1
507
- 0,0.0008505032,0
508
- 0,0.15627518,0
509
- 1,0.53720474,1
510
- 1,0.42872614,0
511
- 0,0.0009015459,0
512
- 1,0.16489983,0
513
- 0,0.7569152,1
514
- 0,0.0009473306,0
515
- 0,0.54220945,1
516
- 0,0.0010804973,0
517
- 0,0.0007759088,0
518
- 0,0.21974401,0
519
- 0,0.0009557337,0
520
- 0,0.00080877467,0
521
- 1,0.72012144,1
522
- 1,0.6555891,1
523
- 0,0.0010442814,0
524
- 0,0.23529597,0
525
- 1,0.5538642,1
526
- 0,0.001103702,0
527
- 1,0.75086635,1
528
- 0,0.0010794887,0
529
- 0,0.00087138265,0
530
- 0,0.45431912,0
531
- 0,0.00098219,0
532
- 1,0.24641904,0
533
- 0,0.001045231,0
534
- 0,0.001125692,0
535
- 0,0.00088083575,0
536
- 1,0.7283503,1
537
- 0,0.40620965,0
538
- 0,0.0009369744,0
539
- 1,0.61685985,1
540
- 0,0.0015938416,0
541
- 0,0.0010618207,0
542
- 1,0.6549626,1
543
- 0,0.0011033998,0
544
- 0,0.0010170939,0
545
- 0,0.0009539079,0
546
- 0,0.0013007914,0
547
- 0,0.0009015459,0
548
- 0,0.0014242147,0
549
- 0,0.707509,1
550
- 0,0.33996958,0
551
- 0,0.249575,0
552
- 0,0.0009841922,0
553
- 0,0.00089403824,0
554
- 0,0.7104041,1
555
- 1,0.7468179,1
556
- 0,0.39707997,0
557
- 0,0.0008614993,0
558
- 0,0.0014179454,0
559
- 0,0.023018427,0
560
- 0,0.00091979245,0
561
- 0,0.00094296545,0
562
- 0,0.00087731396,0
563
- 0,0.0014412756,0
564
- 0,0.0010167825,0
565
- 0,0.002101245,0
566
- 0,0.0032875312,0
567
- 0,0.0008558566,0
568
- 1,0.6307645,1
569
- 0,0.777811,1
570
- 0,0.4272062,0
571
- 0,0.35077578,0
572
- 1,0.63610154,1
573
- 0,0.35240352,0
574
- 0,0.0012432288,0
575
- 0,0.0008819291,0
576
- 0,0.16675064,0
577
- 0,0.0055521415,0
578
- 0,0.0008978255,0
579
- 1,0.5747646,1
580
- 0,0.0009587978,0
581
- 1,0.7340868,1
582
- 0,0.0016573233,0
583
- 0,0.0007863164,0
584
- 0,0.00305028,0
585
- 0,0.0009701932,0
586
- 0,0.001084977,0
587
- 0,0.0009680819,0
588
- 0,0.23961695,0
589
- 0,0.31963304,0
590
- 0,0.0011128527,0
591
- 0,0.0016477697,0
592
- 0,0.00095384475,0
593
- 0,0.0012485523,0
594
- 0,0.0027906268,0
595
- 0,0.00086827046,0
596
- 0,0.0009598256,0
597
- 1,0.64515626,1
598
- 0,0.0010624658,0
599
- 1,0.43718228,0
600
- 0,0.0008661823,0
601
- 1,0.6091658,1
602
- 0,0.2584575,0
603
- 0,0.000878247,0
604
- 1,0.5297075,1
605
- 0,0.0008958644,0
606
- 0,0.63287455,1
607
- 0,0.8396176,1
608
- 0,0.0017420241,0
609
- 0,0.00093098823,0
610
- 0,0.0008039557,0
611
- 0,0.0008432391,0
612
- 0,0.34551686,0
613
- 0,0.14493409,0
614
- 0,0.001269443,0
615
- 1,0.6718914,1
616
- 0,0.0011318232,0
617
- 1,0.4247725,0
618
- 0,0.0009978332,0
619
- 0,0.0010476196,0
620
- 0,0.0008974574,0
621
- 0,0.42583707,0
622
- 0,0.00087123946,0
623
- 0,0.17696548,0
624
- 1,0.7879036,1
625
- 0,0.000978515,0
626
- 0,0.0009257359,0
627
- 0,0.0012878132,0
628
- 1,0.64483523,1
629
- 0,0.0008505065,0
630
- 1,0.43661386,0
631
- 1,0.57060575,1
632
- 0,0.42568576,0
633
- 0,0.0009259524,0
634
- 0,0.0009263901,0
635
- 0,0.00083254103,0
636
- 0,0.00087859563,0
637
- 0,0.0030385156,0
638
- 1,0.5120378,1
639
- 0,0.0009727936,0
640
- 0,0.0008510578,0
641
- 0,0.0010575671,0
642
- 1,0.6155601,1
643
- 0,0.00091628404,0
644
- 1,0.82486975,1
645
- 0,0.0011844339,0
646
- 0,0.0010148155,0
647
- 1,0.6681816,1
648
- 0,0.0011263781,0
649
- 0,0.61207116,1
650
- 0,0.62184155,1
651
- 1,0.73041785,1
652
- 0,0.0025686398,0
653
- 0,0.0009894907,0
654
- 0,0.0010610846,0
655
- 0,0.0010225485,0
656
- 0,0.001962883,0
657
- 0,0.0010426701,0
658
- 0,0.001386491,0
659
- 0,0.00080189446,0
660
- 0,0.0019367785,0
661
- 0,0.0008910609,0
662
- 0,0.0017919212,0
663
- 0,0.0009899494,0
664
- 0,0.0008948113,0
665
- 0,0.00095926056,0
666
- 0,0.3376383,0
667
- 0,0.0014171846,0
668
- 1,0.659626,1
669
- 0,0.0014162698,0
670
- 0,0.69116306,1
671
- 0,0.0009343347,0
672
- 0,0.00096477,0
673
- 0,0.000974611,0
674
- 1,0.389206,0
675
- 0,0.00087169366,0
676
- 0,0.12066001,0
677
- 0,0.0010589029,0
678
- 1,0.5742767,1
679
- 0,0.0012632777,0
680
- 0,0.0028386211,0
681
- 1,0.7000751,1
682
- 0,0.711954,1
683
- 1,0.2301944,0
684
- 0,0.593717,1
685
- 0,0.0010470055,0
686
- 1,0.20598996,0
687
- 0,0.0011602843,0
688
- 0,0.0009489052,0
689
- 0,0.0009171116,0
690
- 0,0.00077974907,0
691
- 0,0.00092091836,0
692
- 0,0.4514688,0
693
- 0,0.0008686565,0
694
- 1,0.68002725,1
695
- 0,0.0009320532,0
696
- 0,0.00096953986,0
697
- 1,0.47608092,0
698
- 1,0.43994707,0
699
- 0,0.0016875481,0
700
- 0,0.0013200458,0
701
- 1,0.30488643,0
702
- 1,0.2136204,0
703
- 1,0.7080725,1
704
- 0,0.35315382,0
705
- 0,0.0007530385,0
706
- 1,0.3350419,0
707
- 0,0.0010032223,0
708
- 0,0.0010372837,0
709
- 0,0.5047713,1
710
- 0,0.0011856316,0
711
- 1,0.5202941,1
712
- 0,0.036287144,0
713
- 0,0.0015443955,0
714
- 0,0.45689735,0
715
- 0,0.05079241,0
716
- 0,0.00078609615,0
717
- 0,0.00089042104,0
718
- 0,0.00091053615,0
719
- 1,0.8260853,1
720
- 0,0.0012496725,0
721
- 0,0.001003521,0
722
- 0,0.0014080106,0
723
- 0,0.43465498,0
724
- 1,0.7085056,1
725
- 0,0.1071419,0
726
- 0,0.38532647,0
727
- 0,0.0007924066,0
728
- 0,0.0012905765,0
729
- 0,0.38276187,0
730
- 0,0.6617229,1
731
- 1,0.34884775,0
732
- 0,0.0024217672,0
733
- 0,0.0009956957,0
734
- 0,0.25291744,0
735
- 1,0.6034158,1
736
- 0,0.0022521024,0
737
- 0,0.0009386203,0
738
- 1,0.50254047,1
739
- 0,0.00085585134,0
740
- 0,0.59543693,1
741
- 0,0.0011632922,0
742
- 0,0.5392403,1
743
- 1,0.7379359,1
744
- 0,0.615833,1
745
- 0,0.0011334324,0
746
- 1,0.6452454,1
747
- 1,0.5439059,1
748
- 1,0.4070706,0
749
- 0,0.00085760804,0
750
- 0,0.0013209702,0
751
- 0,0.00088978535,0
752
- 0,0.17897561,0
753
- 0,0.0008940443,0
754
- 0,0.0021460964,0
755
- 0,0.46505737,0
756
- 0,0.00095002423,0
757
- 0,0.0011486006,0
758
- 0,0.0008134391,0
759
- 1,0.61036927,1
760
- 1,0.82847095,1
761
- 0,0.0008939127,0
762
- 1,0.74133915,1
763
- 0,0.0010881014,0
764
- 1,0.6198983,1
765
- 0,0.00094616745,0
766
- 1,0.37103793,0
767
- 0,0.0011978437,0
768
- 0,0.21946022,0
769
- 0,0.0010989493,0
770
- 0,0.0011152511,0
771
- 1,0.6064778,1
772
- 0,0.0021220825,0
773
- 0,0.001013543,0
774
- 0,0.00082881283,0
775
- 0,0.0016203971,0
776
- 0,0.71427095,1
777
- 0,0.5296034,1
778
- 0,0.0007731554,0
779
- 0,0.0011991286,0
780
- 0,0.0014669152,0
781
- 0,0.56059104,1
782
- 0,0.0009583868,0
783
- 1,0.83952165,1
784
- 0,0.22009522,0
785
- 0,0.001713881,0
786
- 0,0.0011007866,0
787
- 0,0.22656342,0
788
- 0,0.0026020007,0
789
- 0,0.0018667651,0
790
- 0,0.5471779,1
791
- 0,0.001456346,0
792
- 0,0.0010325527,0
793
- 0,0.20448126,0
794
- 0,0.001224137,0
795
- 1,0.6375793,1
796
- 0,0.2193653,0
797
- 0,0.0008086656,0
798
- 0,0.0012699576,0
799
- 0,0.0008365637,0
800
- 0,0.57084,1
801
- 0,0.0010449449,0
802
- 1,0.20258097,0
803
- 0,0.0010269393,0
804
- 1,0.4684375,0
805
- 0,0.0009493274,0
806
- 0,0.0009980378,0
807
- 0,0.0009156465,0
808
- 0,0.6496492,1
809
- 0,0.34035468,0
810
- 0,0.0010597436,0
811
- 1,0.47263625,0
812
- 0,0.0009759152,0
813
- 1,0.78246176,1
814
- 0,0.4524046,0
815
- 1,0.38766143,0
816
- 1,0.7272198,1
817
- 0,0.0008958892,0
818
- 0,0.0010322925,0
819
- 0,0.0010094311,0
820
- 0,0.0012807564,0
821
- 0,0.5271925,1
822
- 0,0.36517152,0
823
- 0,0.10721343,0
824
- 0,0.0009403063,0
825
- 0,0.00077919016,0
826
- 0,0.0009822751,0
827
- 0,0.00076234364,0
828
- 0,0.00078162784,0
829
- 1,0.5729158,1
830
- 0,0.00089957967,0
831
- 0,0.0009848445,0
832
- 1,0.551513,1
833
- 1,0.7289652,1
834
- 0,0.0010857919,0
835
- 1,0.5924267,1
836
- 0,0.0015767976,0
837
- 0,0.0008469038,0
838
- 1,0.90326667,1
839
- 0,0.0009163141,0
840
- 0,0.0011933714,0
841
- 0,0.0011287286,0
842
- 0,0.0008222916,0
843
- 0,0.008933486,0
844
- 0,0.0018526828,0
845
- 0,0.00112532,0
846
- 0,0.08725183,0
847
- 0,0.0011224038,0
848
- 0,0.0024273458,0
849
- 0,0.0020511525,0
850
- 0,0.0021996044,0
851
- 1,0.7084392,1
852
- 0,0.0013229746,0
853
- 0,0.017357908,0
854
- 0,0.002221878,0
855
- 0,0.00085381826,0
856
- 0,0.0045482507,0
857
- 0,0.27093408,0
858
- 0,0.50731814,1
859
- 0,0.0012390758,0
860
- 1,0.56386113,1
861
- 0,0.00089642877,0
862
- 0,0.0009294908,0
863
- 0,0.0012274805,0
864
- 0,0.0009345854,0
865
- 0,0.00086570746,0
866
- 0,0.0011756949,0
867
- 0,0.38236865,0
868
- 1,0.687024,1
869
- 0,0.23687929,0
870
- 0,0.0016748687,0
871
- 1,0.6843725,1
872
- 0,0.304691,0
873
- 0,0.0009893903,0
874
- 0,0.0012283546,0
875
- 0,0.0009146128,0
876
- 0,0.7573988,1
877
- 1,0.34269,0
878
- 0,0.0034435734,0
879
- 0,0.00087736914,0
880
- 0,0.0015209204,0
881
- 0,0.0009608211,0
882
- 0,0.0009280081,0
883
- 0,0.0012842439,0
884
- 1,0.6248447,1
885
- 1,0.62274516,1
886
- 0,0.00165578,0
887
- 0,0.0008775915,0
888
- 0,0.18800546,0
889
- 1,0.777882,1
890
- 1,0.76358753,1
891
- 0,0.00093880715,0
892
- 0,0.7536075,1
893
- 0,0.3016229,0
894
- 0,0.0011477552,0
895
- 0,0.0014952337,0
896
- 0,0.0009555025,0
897
- 0,0.15660593,0
898
- 0,0.0011227432,0
899
- 0,0.001997399,0
900
- 0,0.56990355,1
901
- 0,0.3734899,0
902
- 1,0.5575483,1
903
- 1,0.6860012,1
904
- 1,0.10437922,0
905
- 1,0.8180956,1
906
- 0,0.0011188495,0
907
- 0,0.0010619572,0
908
- 0,0.56458724,1
909
- 0,0.0010933846,0
910
- 0,0.0012181381,0
911
- 0,0.4674541,0
912
- 1,0.485787,0
913
- 0,0.0013055103,0
914
- 1,0.70587736,1
915
- 0,0.00079359673,0
916
- 1,0.80720824,1
917
- 0,0.0016512059,0
918
- 1,0.7673012,1
919
- 0,0.0010114614,0
920
- 0,0.0013267859,0
921
- 0,0.0008793849,0
922
- 0,0.0013299336,0
923
- 0,0.08014288,0
924
- 0,0.0035911172,0
925
- 0,0.0009812173,0
926
- 1,0.25576597,0
927
- 1,0.7145252,1
928
- 1,0.74173844,1
929
- 1,0.7591139,1
930
- 1,0.79900056,1
931
- 0,0.0010525746,0
932
- 0,0.0010317984,0
933
- 0,0.0009139964,0
934
- 0,0.0011968515,0
935
- 0,0.0010743748,0
936
- 0,0.46601886,0
937
- 0,0.0010757077,0
938
- 0,0.00852641,0
939
- 0,0.002048128,0
940
- 1,0.80355567,1
941
- 0,0.0014045321,0
942
- 0,0.0011115061,0
943
- 0,0.0008338689,0
944
- 1,0.5675804,1
945
- 0,0.0008722262,0
946
- 0,0.001032084,0
947
- 0,0.00090455357,0
948
- 0,0.000761143,0
949
- 1,0.25547284,0
950
- 0,0.72657347,1
951
- 0,0.00095886085,0
952
- 0,0.0007715649,0
953
- 0,0.00086148566,0
954
- 0,0.0017338935,0
955
- 0,0.0029073667,0
956
- 1,0.74649566,1
957
- 0,0.00082679826,0
958
- 0,0.00097232254,0
959
- 0,0.001773119,0
960
- 0,0.0009934949,0
961
- 1,0.8185504,1
962
- 0,0.0017876541,0
963
- 0,0.0008482355,0
964
- 0,0.00088651665,0
965
- 0,0.0028928719,0
966
- 0,0.000960697,0
967
- 0,0.6135191,1
968
- 0,0.0015025416,0
969
- 0,0.00096636615,0
970
- 1,0.27818596,0
971
- 0,0.0010354616,0
972
- 1,0.7740497,1
973
- 0,0.0010154156,0
974
- 0,0.0010276875,0
975
- 0,0.0013553945,0
976
- 0,0.64635444,1
977
- 0,0.36508304,0
978
- 0,0.001062856,0
979
- 0,0.00074448035,0
980
- 0,0.0013278496,0
981
- 0,0.43072903,0
982
- 0,0.3727711,0
983
- 0,0.1476106,0
984
- 1,0.7469999,1
985
- 1,0.5393194,1
986
- 1,0.62362605,1
987
- 0,0.4977242,0
988
- 0,0.0019504097,0
989
- 0,0.0011121453,0
990
- 0,0.00094129675,0
991
- 1,0.13999684,0
992
- 1,0.84094584,1
993
- 1,0.67023414,1
994
- 0,0.0009355351,0
995
- 0,0.0012011972,0
996
- 1,0.85856575,1
997
- 0,0.00095948775,0
998
- 0,0.00092585845,0
999
- 1,0.78414935,1
1000
- 0,0.2081838,0
1001
- 1,0.3134156,0
1002
- 0,0.5827518,1
1003
- 0,0.0011884035,0
1004
- 0,0.3416024,0
1005
- 0,0.3513188,0
1006
- 1,0.7360253,1
1007
- 0,0.14421782,0
1008
- 0,0.0008775974,0
1009
- 0,0.0010155869,0
1010
- 1,0.6316725,1
1011
- 0,0.0011462267,0
1012
- 0,0.00090081355,0
1013
- 0,0.46972373,0
1014
- 0,0.0010553041,0
1015
- 0,0.5809954,1
1016
- 0,0.0059841666,0
1017
- 1,0.6962589,1
1018
- 0,0.004652636,0
1019
- 0,0.0008641569,0
1020
- 1,0.80275476,1
1021
- 0,0.0008407392,0
1022
- 0,0.0010207775,0
1023
- 0,0.0011965961,0
1024
- 1,0.41581193,0
1025
- 0,0.002618945,0
1026
- 0,0.001120625,0
1027
- 0,0.00090287067,0
1028
- 1,0.6921951,1
1029
- 0,0.32014215,0
1030
- 1,0.44863924,0
1031
- 0,0.20511761,0
1032
- 0,0.0008183238,0
1033
- 0,0.27403337,0
1034
- 0,0.0014313993,0
1035
- 0,0.52280444,1
1036
- 0,0.00095016713,0
1037
- 0,0.001587931,0
1038
- 0,0.2349104,0
1039
- 1,0.65222037,1
1040
- 0,0.0015776413,0
1041
- 0,0.0012768807,0
1042
- 0,0.0011938995,0
1043
- 1,0.8248594,1
1044
- 1,0.5302688,1
1045
- 0,0.0017132758,0
1046
- 0,0.0008661427,0
1047
- 0,0.0013237718,0
1048
- 0,0.30642387,0
1049
- 0,0.0013227944,0
1050
- 0,0.0007931251,0
1051
- 0,0.000937015,0
1052
- 0,0.0011775857,0
1053
- 0,0.0031580946,0
1054
- 0,0.0017192552,0
1055
- 0,0.00090431585,0
1056
- 0,0.0008876776,0
1057
- 0,0.3237665,0
1058
- 0,0.0011375708,0
1059
- 0,0.713965,1
1060
- 1,0.6401749,1
1061
- 0,0.055340905,0
1062
- 0,0.0008713582,0
1063
- 0,0.0011798014,0
1064
- 0,0.19061498,0
1065
- 0,0.0008910244,0
1066
- 1,0.6702638,1
1067
- 0,0.00078732,0
1068
- 1,0.5069871,1
1069
- 0,0.0010917349,0
1070
- 1,0.30569232,0
1071
- 0,0.0010508688,0
1072
- 0,0.008350478,0
1073
- 0,0.1295904,0
1074
- 0,0.0012138723,0
1075
- 0,0.23459788,0
1076
- 0,0.00095988956,0
1077
- 0,0.0011302889,0
1078
- 0,0.41533685,0
1079
- 1,0.81494963,1
1080
- 1,0.20318276,0
1081
- 1,0.6849259,1
1082
- 0,0.0007971484,0
1083
- 0,0.0010026495,0
1084
- 1,0.53866345,1
1085
- 0,0.47900033,0
1086
- 0,0.0013153972,0
1087
- 0,0.00092683197,0
1088
- 0,0.00079303206,0
1089
- 0,0.0010194737,0
1090
- 0,0.00090722676,0
1091
- 0,0.0014398387,0
1092
- 0,0.0008348118,0
1093
- 1,0.8690826,1
1094
- 0,0.45386383,0
1095
- 0,0.0009348062,0
1096
- 1,0.7537117,1
1097
- 0,0.0009848943,0
1098
- 1,0.8337264,1
1099
- 0,0.0013707685,0
1100
- 1,0.64028686,1
1101
- 0,0.0009233346,0
1102
- 0,0.0011110144,0
1103
- 0,0.00091898517,0
1104
- 0,0.37525868,0
1105
- 0,0.0012479519,0
1106
- 0,0.0011017517,0
1107
- 0,0.23442647,0
1108
- 0,0.001286486,0
1109
- 0,0.0010251929,0
1110
- 0,0.0012779953,0
1111
- 0,0.10511962,0
1112
- 0,0.0013187885,0
1113
- 0,0.0008432262,0
1114
- 0,0.0023270152,0
1115
- 0,0.19635822,0
1116
- 0,0.42605725,0
1117
- 0,0.0026588596,0
1118
- 0,0.0010699191,0
1119
- 0,0.00083928637,0
1120
- 0,0.0010421542,0
1121
- 0,0.0010161884,0
1122
- 0,0.0007994666,0
1123
- 0,0.0011844927,0
1124
- 1,0.51526463,1
1125
- 1,0.8424703,1
1126
- 0,0.0008336365,0
1127
- 0,0.0008402345,0
1128
- 0,0.0010367532,0
1129
- 0,0.0008751524,0
1130
- 0,0.0013134012,0
1131
- 0,0.0008601877,0
1132
- 0,0.28373632,0
1133
- 0,0.0010632672,0
1134
- 0,0.001001677,0
1135
- 0,0.0009289228,0
1136
- 0,0.0010856837,0
1137
- 0,0.0009986022,0
1138
- 0,0.0011016888,0
1139
- 0,0.0009970806,0
1140
- 0,0.0014796017,0
1141
- 0,0.0009895341,0
1142
- 1,0.87620765,1
1143
- 0,0.13668354,0
1144
- 0,0.5543678,1
1145
- 0,0.0011205432,0
1146
- 1,0.60824937,1
1147
- 1,0.8209723,1
1148
- 1,0.77438575,1
1149
- 0,0.44236442,0
1150
- 1,0.5708855,1
1151
- 0,0.0014770482,0
1152
- 0,0.0011481309,0
1153
- 1,0.76328576,1
1154
- 0,0.0012447835,0
1155
- 0,0.6776131,1
1156
- 0,0.22794476,0
1157
- 0,0.0015239945,0
1158
- 0,0.0014258837,0
1159
- 0,0.2854073,0
1160
- 0,0.0008883935,0
1161
- 1,0.57807,1
1162
- 0,0.0011267756,0
1163
- 1,0.61485714,1
1164
- 0,0.0007958319,0
1165
- 0,0.0010823915,0
1166
- 0,0.2250206,0
1167
- 1,0.4204774,0
1168
- 0,0.00093722483,0
1169
- 0,0.0017883043,0
1170
- 0,0.0009139735,0
1171
- 0,0.00089569465,0
1172
- 0,0.0013786783,0
1173
- 0,0.63649786,1
1174
- 0,0.00097638293,0
1175
- 1,0.73162854,1
1176
- 0,0.00088065065,0
1177
- 0,0.0010577292,0
1178
- 0,0.0011432399,0
1179
- 1,0.6012633,1
1180
- 0,0.0011757269,0
1181
- 0,0.0008843888,0
1182
- 0,0.00094909506,0
1183
- 0,0.0009065131,0
1184
- 1,0.50341845,1
1185
- 1,0.3092106,0
1186
- 0,0.087408595,0
1187
- 1,0.47898576,0
1188
- 0,0.7805265,1
1189
- 0,0.55192417,1
1190
- 0,0.100362,0
1191
- 0,0.0009126828,0
1192
- 1,0.7031464,1
1193
- 0,0.0008378504,0
1194
- 0,0.000909575,0
1195
- 0,0.14698155,0
1196
- 0,0.000886812,0
1197
- 0,0.0034394178,0
1198
- 0,0.00083480857,0
1199
- 0,0.0009230623,0
1200
- 0,0.00086228485,0
1201
- 0,0.0012248036,0
1202
- 0,0.3685421,0
1203
- 1,0.75045246,1
1204
- 1,0.29140982,0
1205
- 1,0.824118,1
1206
- 0,0.0011165916,0
1207
- 0,0.0015328506,0
1208
- 0,0.49720347,0
1209
- 1,0.6922905,1
1210
- 0,0.0007836441,0
1211
- 0,0.000927602,0
1212
- 0,0.70492715,1
1213
- 0,0.0008786006,0
1214
- 0,0.7033895,1
1215
- 1,0.6138756,1
1216
- 1,0.5470745,1
1217
- 1,0.28883335,0
1218
- 0,0.001316554,0
1219
- 0,0.000900596,0
1220
- 1,0.77885973,1
1221
- 0,0.3792245,0
1222
- 0,0.0010777283,0
1223
- 0,0.022616526,0
1224
- 0,0.3294621,0
1225
- 0,0.00081920385,0
1226
- 0,0.0011605065,0
1227
- 0,0.2769352,0
1228
- 0,0.80774456,1
1229
- 0,0.000816923,0
1230
- 0,0.0010059879,0
1231
- 0,0.0009761908,0
1232
- 0,0.0019736588,0
1233
- 0,0.001108255,0
1234
- 0,0.00084299047,0
1235
- 0,0.0007932735,0
1236
- 0,0.0010224554,0
1237
- 0,0.0010019916,0
1238
- 0,0.0011584796,0
1239
- 0,0.00095932867,0
1240
- 0,0.0014852965,0
1241
- 0,0.0011517362,0
1242
- 0,0.0016003981,0
1243
- 0,0.5883149,1
1244
- 0,0.0011397423,0
1245
- 0,0.54703045,1
1246
- 0,0.46631202,0
1247
- 1,0.72806704,1
1248
- 1,0.66733813,1
1249
- 0,0.00082930306,0
1250
- 0,0.0013221847,0
1251
- 0,0.37714672,0
1252
- 1,0.6671186,1
1253
- 0,0.76171786,1
1254
- 1,0.84557354,1
1255
- 0,0.0009865002,0
1256
- 0,0.00078149413,0
1257
- 0,0.0020528194,0
1258
- 0,0.001968213,0
1259
- 1,0.29894271,0
1260
- 1,0.65170336,1
1261
- 0,0.00087412616,0
1262
- 0,0.0008334153,0
1263
- 0,0.002001055,0
1264
- 0,0.0010972196,0
1265
- 1,0.6604654,1
1266
- 1,0.75812054,1
1267
- 1,0.69461435,1
1268
- 0,0.20077878,0
1269
- 0,0.19034809,0
1270
- 0,0.001009536,0
1271
- 0,0.09553723,0
1272
- 1,0.4960136,0
1273
- 0,0.0012337598,0
1274
- 0,0.0030067663,0
1275
- 0,0.53967786,1
1276
- 0,0.0012488324,0
1277
- 0,0.001111368,0
1278
- 0,0.001234995,0
1279
- 0,0.0011818678,0
1280
- 0,0.42678392,0
1281
- 1,0.73771423,1
1282
- 0,0.7895419,1
1283
- 0,0.0012994151,0
1284
- 0,0.00084598584,0
1285
- 0,0.0007987745,0
1286
- 0,0.0012115806,0
1287
- 1,0.58688515,1
1288
- 0,0.00091692246,0
1289
- 0,0.0010023817,0
1290
- 1,0.7862743,1
1291
- 1,0.25347495,0
1292
- 0,0.3326281,0
1293
- 0,0.3226471,0
1294
- 0,0.6390405,1
1295
- 0,0.54111165,1
1296
- 0,0.001012588,0
1297
- 0,0.0009537402,0
1298
- 0,0.0008496503,0
1299
- 1,0.6967211,1
1300
- 0,0.0010966522,0
1301
- 0,0.0011798083,0
1302
- 0,0.6963768,1
1303
- 0,0.00095325924,0
1304
- 0,0.0011999392,0
1305
- 0,0.32601213,0
1306
- 1,0.5316113,1
1307
- 0,0.46846902,0
1308
- 0,0.00096525403,0
1309
- 0,0.0011292063,0
1310
- 0,0.001140362,0
1311
- 0,0.0012218289,0
1312
- 0,0.40930805,0
1313
- 0,0.0011080602,0
1314
- 1,0.5122479,1
1315
- 1,0.59998643,1
1316
- 0,0.061730716,0
1317
- 1,0.41931114,0
1318
- 0,0.16193198,0
1319
- 0,0.0007972308,0
1320
- 0,0.0009665351,0
1321
- 0,0.0010098865,0
1322
- 0,0.0020258357,0
1323
- 0,0.46376464,0
1324
- 0,0.28660434,0
1325
- 0,0.0010756479,0
1326
- 0,0.5611161,1
1327
- 0,0.1110538,0
1328
- 0,0.0015265195,0
1329
- 0,0.00088010135,0
1330
- 0,0.23574975,0
1331
- 0,0.42993984,0
1332
- 0,0.0012539547,0
1333
- 0,0.39793822,0
1334
- 0,0.0008850873,0
1335
- 0,0.0010175374,0
1336
- 0,0.57915115,1
1337
- 0,0.60442656,1
1338
- 0,0.0008875359,0
1339
- 0,0.0010697005,0
1340
- 0,0.0009562333,0
1341
- 0,0.0019253149,0
1342
- 1,0.77746564,1
1343
- 0,0.0015686869,0
1344
- 1,0.31335396,0
1345
- 0,0.000908781,0
1346
- 0,0.0013533514,0
1347
- 0,0.2515571,0
1348
- 0,0.00089592073,0
1349
- 1,0.25762314,0
1350
- 1,0.6580362,1
1351
- 0,0.000960752,0
1352
- 0,0.0010607035,0
1353
- 0,0.0077122697,0
1354
- 1,0.6530093,1
1355
- 0,0.00087614363,0
1356
- 0,0.007246284,0
1357
- 0,0.6812245,1
1358
- 0,0.23485817,0
1359
- 0,0.00086641597,0
1360
- 0,0.0012842317,0
1361
- 0,0.0007809932,0
1362
- 1,0.5749303,1
1363
- 0,0.0015573403,0
1364
- 0,0.13056204,0
1365
- 0,0.357255,0
1366
- 0,0.34036377,0
1367
- 0,0.3980148,0
1368
- 1,0.69733506,1
1369
- 0,0.25944367,0
1370
- 0,0.34873602,0
1371
- 0,0.0009726848,0
1372
- 1,0.49505424,0
1373
- 1,0.6892435,1
1374
- 1,0.7743485,1
1375
- 0,0.32794416,0
1376
- 1,0.22797777,0
1377
- 0,0.0009458529,0
1378
- 0,0.38203833,0
1379
- 0,0.0012890919,0
1380
- 0,0.000789744,0
1381
- 0,0.31899726,0
1382
- 0,0.6728214,1
1383
- 0,0.000999975,0
1384
- 0,0.00094006164,0
1385
- 0,0.00075543154,0
1386
- 0,0.0013109997,0
1387
- 0,0.26060212,0
1388
- 0,0.2962978,0
1389
- 1,0.81047857,1
1390
- 0,0.0009169193,0
1391
- 0,0.0009614863,0
1392
- 0,0.00097122806,0
1393
- 0,0.0015418291,0
1394
- 0,0.0010950873,0
1395
- 0,0.0043601934,0
1396
- 1,0.6357652,1
1397
- 0,0.001165279,0
1398
- 0,0.0009673975,0
1399
- 1,0.83262265,1
1400
- 0,0.0009282163,0
1401
- 0,0.70172244,1
1402
- 0,0.0009607166,0
1403
- 0,0.00084028306,0
1404
- 0,0.6203581,1
1405
- 0,0.2677468,0
1406
- 1,0.7861525,1
1407
- 1,0.76065445,1
1408
- 0,0.0022763049,0
1409
- 0,0.0008973102,0
1410
- 0,0.0008466766,0
1411
- 0,0.0009204096,0
1412
- 1,0.5665215,1
1413
- 0,0.50933987,1
1414
- 0,0.0011783128,0
1415
- 0,0.00090577826,0
1416
- 0,0.47160682,0
1417
- 1,0.86713326,1
1418
- 0,0.001341303,0
1419
- 1,0.16438311,0
1420
- 1,0.5899627,1
1421
- 0,0.0009171007,0
1422
- 0,0.0009108268,0
1423
- 0,0.00084698235,0
1424
- 0,0.0007781086,0
1425
- 0,0.0010233963,0
1426
- 0,0.24456206,0
1427
- 0,0.0014641931,0
1428
- 0,0.0007724079,0
1429
- 0,0.35753417,0
1430
- 0,0.000865038,0
1431
- 0,0.0031078586,0
1432
- 0,0.0008967957,0
1433
- 1,0.77871156,1
1434
- 1,0.58639294,1
1435
- 0,0.0026566426,0
1436
- 0,0.00082757784,0
1437
- 0,0.0011661336,0
1438
- 0,0.001019244,0
1439
- 1,0.7419938,1
1440
- 1,0.42898682,0
1441
- 1,0.76121277,1
1442
- 0,0.0012546297,0
1443
- 1,0.37173283,0
1444
- 1,0.60157835,1
1445
- 1,0.3550261,0
1446
- 0,0.00083142443,0
1447
- 0,0.010433526,0
1448
- 0,0.0008385298,0
1449
- 0,0.0010685887,0
1450
- 0,0.0014192335,0
1451
- 0,0.00084556115,0
1452
- 0,0.25132394,0
1453
- 0,0.00080387027,0
1454
- 0,0.0009537752,0
1455
- 1,0.77484345,1
1456
- 0,0.0010258998,0
1457
- 0,0.00085469737,0
1458
- 1,0.80244,1
1459
- 0,0.6002382,1
1460
- 0,0.00085503253,0
1461
- 0,0.00087335554,0
1462
- 0,0.1733697,0
1463
- 0,0.0013072107,0
1464
- 0,0.31264597,0
1465
- 0,0.82034665,1
1466
- 1,0.67145926,1
1467
- 1,0.5704474,1
1468
- 0,0.0019223319,0
1469
- 0,0.0009967473,0
1470
- 0,0.4709534,0
1471
- 0,0.0044812486,0
1472
- 0,0.0009170116,0
1473
- 1,0.7392076,1
1474
- 0,0.0061432202,0
1475
- 0,0.45998988,0
1476
- 0,0.0033909979,0
1477
- 0,0.0011331846,0
1478
- 0,0.0010110885,0
1479
- 0,0.0011058631,0
1480
- 1,0.41766933,0
1481
- 0,0.00071926217,0
1482
- 0,0.54065317,1
1483
- 1,0.34456867,0
1484
- 0,0.0010258704,0
1485
- 0,0.0016063552,0
1486
- 0,0.3136817,0
1487
- 0,0.0009535526,0
1488
- 0,0.0010021557,0
1489
- 0,0.0009705446,0
1490
- 0,0.14366324,0
1491
- 0,0.11577339,0
1492
- 1,0.5175177,1
1493
- 0,0.0010051089,0
1494
- 0,0.0010658188,0
1495
- 0,0.00097995,0
1496
- 0,0.0011440237,0
1497
- 1,0.54906136,1
1498
- 0,0.4807756,0
1499
- 0,0.300868,0
1500
- 0,0.0011075152,0
1501
- 0,0.00085573347,0
1502
- 0,0.00085721305,0
1503
- 0,0.0010270772,0
1504
- 1,0.75990576,1
1505
- 0,0.06632882,0
1506
- 0,0.001487759,0
1507
- 1,0.20113204,0
1508
- 0,0.0013309364,0
1509
- 0,0.000907346,0
1510
- 1,0.5900306,1
1511
- 0,0.000946925,0
1512
- 0,0.0010906772,0
1513
- 0,0.41585302,0
1514
- 1,0.78238595,1
1515
- 0,0.0009065691,0
1516
- 0,0.0008776255,0
1517
- 1,0.788677,1
1518
- 0,0.001570436,0
1519
- 0,0.22623111,0
1520
- 0,0.0011954956,0
1521
- 0,0.26859984,0
1522
- 0,0.0012529476,0
1523
- 0,0.00084351,0
1524
- 0,0.0009662851,0
1525
- 0,0.0008385861,0
1526
- 0,0.0011187536,0
1527
- 0,0.000936137,0
1528
- 0,0.0012308148,0
1529
- 0,0.00084265514,0
1530
- 1,0.55672204,1
1531
- 1,0.7771148,1
1532
- 0,0.0012370128,0
1533
- 0,0.0010973206,0
1534
- 0,0.00087856123,0
1535
- 0,0.11740843,0
1536
- 0,0.0010494886,0
1537
- 0,0.0016682858,0
1538
- 0,0.0012532771,0
1539
- 0,0.001405361,0
1540
- 0,0.20851627,0
1541
- 1,0.6311097,1
1542
- 0,0.0008749682,0
1543
- 1,0.6411818,1
1544
- 1,0.83963007,1
1545
- 0,0.0011985023,0
1546
- 0,0.0010454545,0
1547
- 0,0.7701166,1
1548
- 0,0.00076997356,0
1549
- 1,0.8383424,1
1550
- 1,0.6781555,1
1551
- 0,0.0036841896,0
1552
- 1,0.59758455,1
1553
- 0,0.043098766,0
1554
- 0,0.0008827312,0
1555
- 0,0.00080906873,0
1556
- 1,0.78783256,1
1557
- 0,0.0011804759,0
1558
- 0,0.46424973,0
1559
- 0,0.0017560824,0
1560
- 0,0.004443426,0
1561
- 0,0.29807636,0
1562
- 0,0.0010315041,0
1563
- 0,0.5789729,1
1564
- 0,0.00084546005,0
1565
- 0,0.0011440341,0
1566
- 1,0.81396484,1
1567
- 1,0.56533325,1
1568
- 0,0.0011488326,0
1569
- 0,0.0012227881,0
1570
- 0,0.25655642,0
1571
- 1,0.5899143,1
1572
- 0,0.54285944,1
1573
- 0,0.0012625692,0
1574
- 0,0.00080501626,0
1575
- 0,0.0011473355,0
1576
- 0,0.0011238786,0
1577
- 0,0.0021833822,0
1578
- 0,0.0008517226,0
1579
- 0,0.0009454461,0
1580
- 0,0.0010944246,0
1581
- 0,0.024547417,0
1582
- 0,0.0010403111,0
1583
- 0,0.26350185,0
1584
- 0,0.0010104905,0
1585
- 0,0.09447297,0
1586
- 0,0.000852578,0
1587
- 0,0.0012083028,0
1588
- 0,0.6784862,1
1589
- 0,0.0009658967,0
1590
- 0,0.0010201805,0
1591
- 0,0.16826008,0
1592
- 0,0.0008753944,0
1593
- 0,0.00078190427,0
1594
- 0,0.20378338,0
1595
- 1,0.6095833,1
1596
- 1,0.55670387,1
1597
- 0,0.47983488,0
1598
- 1,0.24339448,0
1599
- 0,0.0013973623,0
1600
- 0,0.0008691309,0
1601
- 1,0.8703108,1
1602
- 0,0.002405853,0
1603
- 0,0.0011524003,0
1604
- 0,0.55783266,1
1605
- 0,0.0012722318,0
1606
- 0,0.00088787306,0
1607
- 0,0.1720217,0
1608
- 0,0.0009373888,0
1609
- 0,0.0043997974,0
1610
- 1,0.34884313,0
1611
- 1,0.3523087,0
1612
- 0,0.65511626,1
1613
- 0,0.0014629874,0
1614
- 0,0.37225887,0
1615
- 0,0.0012105026,0
1616
- 0,0.0010079421,0
1617
- 0,0.0010033519,0
1618
- 0,0.0008274714,0
1619
- 0,0.00087697804,0
1620
- 0,0.20703182,0
1621
- 0,0.0011603661,0
1622
- 0,0.0014855879,0
1623
- 0,0.22130914,0
1624
- 0,0.00086291565,0
1625
- 0,0.0008673242,0
1626
- 0,0.0011818307,0
1627
- 0,0.00096120656,0
1628
- 0,0.000923244,0
1629
- 0,0.4318089,0
1630
- 1,0.31608683,0
1631
- 1,0.77528465,1
1632
- 0,0.0013540791,0
1633
- 0,0.0008699616,0
1634
- 0,0.00094812346,0
1635
- 0,0.51795197,1
1636
- 1,0.61414665,1
1637
- 1,0.7352273,1
1638
- 0,0.00086347247,0
1639
- 0,0.0008687025,0
1640
- 0,0.0011694061,0
1641
- 0,0.0011693755,0
1642
- 0,0.455578,0
1643
- 0,0.105311655,0
1644
- 0,0.0008285432,0
1645
- 0,0.0010142302,0
1646
- 1,0.19479476,0
1647
- 0,0.0008654564,0
1648
- 0,0.0011336721,0
1649
- 0,0.0011340285,0
1650
- 0,0.24410512,0
1651
- 0,0.00090666034,0
1652
- 0,0.0015996577,0
1653
- 1,0.73918,1
1654
- 0,0.000983292,0
1655
- 0,0.0013622475,0
1656
- 0,0.0012581018,0
1657
- 0,0.00080613553,0
1658
- 0,0.0009217467,0
1659
- 1,0.26258734,0
1660
- 0,0.0012133989,0
1661
- 0,0.001480057,0
1662
- 0,0.7632006,1
1663
- 1,0.5682584,1
1664
- 0,0.0014863011,0
1665
- 0,0.63781154,1
1666
- 0,0.0011002115,0
1667
- 0,0.00084043556,0
1668
- 0,0.0014428858,0
1669
- 0,0.0013104859,0
1670
- 1,0.76846975,1
1671
- 1,0.6673162,1
1672
- 0,0.0008704569,0
1673
- 0,0.0008219127,0
1674
- 0,0.65152436,1
1675
- 0,0.0008993478,0
1676
- 0,0.0008565709,0
1677
- 0,0.14742456,0
1678
- 0,0.057693627,0
1679
- 0,0.0008922193,0
1680
- 1,0.15329692,0
1681
- 1,0.65068716,1
1682
- 0,0.0009436021,0
1683
- 0,0.009623958,0
1684
- 0,0.14724772,0
1685
- 0,0.0014856155,0
1686
- 0,0.635374,1
1687
- 0,0.53666717,1
1688
- 0,0.24563298,0
1689
- 0,0.46256322,0
1690
- 0,0.0009238498,0
1691
- 0,0.0008537007,0
1692
- 1,0.7335384,1
1693
- 0,0.00082300097,0
1694
- 1,0.444842,0
1695
- 0,0.0010771926,0
1696
- 0,0.00084987056,0
1697
- 0,0.0020735206,0
1698
- 0,0.0008576367,0
1699
- 0,0.00083232333,0
1700
- 0,0.0011556697,0
1701
- 0,0.0015070596,0
1702
- 0,0.0008087288,0
1703
- 0,0.0013067058,0
1704
- 0,0.21971786,0
1705
- 0,0.00081399345,0
1706
- 1,0.5227279,1
1707
- 0,0.0012097977,0
1708
- 0,0.001634093,0
1709
- 1,0.8266393,1
1710
- 1,0.79590076,1
1711
- 0,0.2257322,0
1712
- 1,0.6879368,1
1713
- 0,0.0008921309,0
1714
- 0,0.000969393,0
1715
- 1,0.77856237,1
1716
- 0,0.00078926905,0
1717
- 0,0.001172239,0
1718
- 1,0.40544796,0
1719
- 1,0.69128567,1
1720
- 0,0.0058039543,0
1721
- 0,0.0010676429,0
1722
- 0,0.0009675725,0
1723
- 0,0.24802871,0
1724
- 0,0.00088985654,0
1725
- 0,0.00090884947,0
1726
- 0,0.48389488,0
1727
- 0,0.67111087,1
1728
- 0,0.0011216454,0
1729
- 0,0.000996518,0
1730
- 1,0.7353937,1
1731
- 1,0.6577438,1
1732
- 1,0.7053728,1
1733
- 1,0.78178835,1
1734
- 0,0.615978,1
1735
- 1,0.6272983,1
1736
- 0,0.004934533,0
1737
- 0,0.0024336318,0
1738
- 0,0.00089967996,0
1739
- 1,0.5925518,1
1740
- 1,0.7473727,1
1741
- 0,0.3549951,0
1742
- 0,0.57792014,1
1743
- 1,0.6967723,1
1744
- 0,0.6283576,1
1745
- 0,0.3178549,0
1746
- 1,0.2248403,0
1747
- 1,0.81427705,1
1748
- 0,0.0010010622,0
1749
- 0,0.0013560583,0
1750
- 0,0.0015472178,0
1751
- 0,0.0009804085,0
1752
- 0,0.0007600732,0
1753
- 1,0.5372315,1
1754
- 0,0.0010776622,0
1755
- 0,0.00083396706,0
1756
- 0,0.0013004588,0
1757
- 0,0.5804702,1
1758
- 0,0.35824117,0
1759
- 0,0.001130411,0
1760
- 1,0.45011675,0
1761
- 0,0.0011523701,0
1762
- 0,0.0009201427,0
1763
- 0,0.0009954533,0
1764
- 0,0.00096182036,0
1765
- 0,0.00093874236,0
1766
- 0,0.4260264,0
1767
- 0,0.0012323499,0
1768
- 0,0.0014474612,0
1769
- 0,0.0009775738,0
1770
- 0,0.2297708,0
1771
- 0,0.69282055,1
1772
- 0,0.0009934746,0
1773
- 0,0.36029232,0
1774
- 1,0.36951026,0
1775
- 0,0.40954018,0
1776
- 0,0.00085924973,0
1777
- 0,0.003017331,0
1778
- 0,0.0016876189,0
1779
- 0,0.0010288198,0
1780
- 1,0.7446905,1
1781
- 0,0.24530704,0
1782
- 0,0.0010668626,0
1783
- 0,0.0037525713,0
1784
- 0,0.001123924,0
1785
- 0,0.0013527669,0
1786
- 0,0.30088598,0
1787
- 0,0.22248338,0
1788
- 0,0.0012103878,0
1789
- 0,0.40475878,0
1790
- 0,0.0008372259,0
1791
- 0,0.46901873,0
1792
- 0,0.0012301363,0
1793
- 0,0.0011264655,0
1794
- 0,0.00089167035,0
1795
- 0,0.00086598546,0
1796
- 0,0.0007683753,0
1797
- 0,0.120619036,0
1798
- 1,0.31516576,0
1799
- 0,0.0010315306,0
1800
- 1,0.82660633,1
1801
- 1,0.16573146,0
1802
- 1,0.45056126,0
1803
- 0,0.012582662,0
1804
- 0,0.2679602,0
1805
- 0,0.0007925501,0
1806
- 0,0.0017378323,0
1807
- 0,0.0017654862,0
1808
- 1,0.65987283,1
1809
- 1,0.6453965,1
1810
- 1,0.553943,1
1811
- 0,0.34653622,0
1812
- 0,0.0009486267,0
1813
- 0,0.0010880433,0
1814
- 1,0.64795786,1
1815
- 1,0.8529027,1
1816
- 1,0.537922,1
1817
- 0,0.0017200281,0
1818
- 0,0.0008353453,0
1819
- 1,0.75242555,1
1820
- 0,0.0007492936,0
1821
- 0,0.38131416,0
1822
- 0,0.61270493,1
1823
- 0,0.3288878,0
1824
- 0,0.6910125,1
1825
- 1,0.75235885,1
1826
- 1,0.7090527,1
1827
- 0,0.00084375514,0
1828
- 0,0.7937191,1
1829
- 0,0.00093711726,0
1830
- 1,0.5996378,1
1831
- 0,0.0009235945,0
1832
- 0,0.39941287,0
1833
- 1,0.42442018,0
1834
- 0,0.0010583275,0
1835
- 0,0.60000926,1
1836
- 0,0.0010812476,0
1837
- 0,0.0009984957,0
1838
- 1,0.8072236,1
1839
- 1,0.4488358,0
1840
- 0,0.0015813457,0
1841
- 1,0.2241001,0
1842
- 1,0.75282294,1
1843
- 0,0.000868323,0
1844
- 1,0.6267692,1
1845
- 0,0.00089177827,0
1846
- 0,0.001014311,0
1847
- 0,0.0008509839,0
1848
- 0,0.5767851,1
1849
- 0,0.0007812928,0
1850
- 0,0.0012110077,0
1851
- 0,0.0009440347,0
1852
- 1,0.7179761,1
1853
- 1,0.4313237,0
1854
- 0,0.0010092129,0
1855
- 0,0.00078347983,0
1856
- 0,0.0011101945,0
1857
- 1,0.53682137,1
1858
- 0,0.0009936662,0
1859
- 0,0.0011565915,0
1860
- 0,0.0010442929,0
1861
- 1,0.39957625,0
1862
- 0,0.001340685,0
1863
- 1,0.5603316,1
1864
- 0,0.22706328,0
1865
- 0,0.0009566926,0
1866
- 0,0.43829724,0
1867
- 0,0.0008824684,0
1868
- 0,0.0011820873,0
1869
- 0,0.0008683048,0
1870
- 0,0.00074347574,0
1871
- 0,0.51406115,1
1872
- 0,0.0025957846,0
1873
- 0,0.0010757236,0
1874
- 1,0.59766775,1
1875
- 0,0.0009575653,0
1876
- 0,0.0012813427,0
1877
- 0,0.001484598,0
1878
- 0,0.8911086,1
1879
- 1,0.41940725,0
1880
- 0,0.104963206,0
1881
- 0,0.22381157,0
1882
- 0,0.50649196,1
1883
- 0,0.0010766535,0
1884
- 0,0.00096266886,0
1885
- 0,0.0012597787,0
1886
- 0,0.0008376636,0
1887
- 0,0.0021784822,0
1888
- 0,0.0014312923,0
1889
- 0,0.0009531075,0
1890
- 0,0.0009719756,0
1891
- 1,0.5082326,1
1892
- 1,0.5103978,1
1893
- 0,0.4345131,0
1894
- 0,0.0010050534,0
1895
- 0,0.0008835647,0
1896
- 0,0.0007883947,0
1897
- 0,0.0022930545,0
1898
- 0,0.0029679493,0
1899
- 0,0.0011517105,0
1900
- 0,0.30619615,0
1901
- 0,0.093977414,0
1902
- 0,0.00095720595,0
1903
- 0,0.0015321785,0
1904
- 1,0.63398594,1
1905
- 0,0.000823775,0
1906
- 0,0.00079622905,0
1907
- 0,0.0009634666,0
1908
- 0,0.49582902,0
1909
- 0,0.0009126799,0
1910
- 0,0.0009068175,0
1911
- 0,0.0008108485,0
1912
- 0,0.0016464454,0
1913
- 0,0.17892216,0
1914
- 0,0.0009123811,0
1915
- 0,0.5026859,1
1916
- 0,0.00095499237,0
1917
- 0,0.0007871654,0
1918
- 0,0.001076556,0
1919
- 0,0.0011965865,0
1920
- 1,0.34323403,0
1921
- 0,0.87479687,1
1922
- 0,0.0008861932,0
1923
- 0,0.35567224,0
1924
- 0,0.36455706,0
1925
- 0,0.0010463424,0
1926
- 0,0.13184956,0
1927
- 0,0.001133562,0
1928
- 0,0.2384513,0
1929
- 0,0.002471907,0
1930
- 0,0.5236164,1
1931
- 1,0.7201213,1
1932
- 0,0.20469072,0
1933
- 1,0.3963374,0
1934
- 1,0.52494746,1
1935
- 0,0.14653996,0
1936
- 0,0.0008450751,0
1937
- 0,0.075510696,0
1938
- 1,0.23577085,0
1939
- 1,0.2894605,0
1940
- 0,0.0015299466,0
1941
- 0,0.7472723,1
1942
- 1,0.71758413,1
1943
- 1,0.63296187,1
1944
- 0,0.0013299449,0
1945
- 0,0.777783,1
1946
- 1,0.6209929,1
1947
- 0,0.00093404716,0
1948
- 0,0.58170086,1
1949
- 0,0.7959799,1
1950
- 0,0.0010129189,0
1951
- 0,0.0009979403,0
1952
- 1,0.44537392,0
1953
- 0,0.561583,1
1954
- 0,0.000845328,0
1955
- 0,0.0010346215,0
1956
- 1,0.67326033,1
1957
- 0,0.0009218446,0
1958
- 0,0.0010009438,0
1959
- 1,0.7646678,1
1960
- 0,0.6851708,1
1961
- 0,0.34248728,0
1962
- 0,0.0010312623,0
1963
- 0,0.001760763,0
1964
- 0,0.35641682,0
1965
- 0,0.0016838545,0
1966
- 0,0.6281676,1
1967
- 1,0.4627605,0
1968
- 0,0.0015306249,0
1969
- 1,0.23714353,0
1970
- 1,0.71622694,1
1971
- 1,0.7319818,1
1972
- 0,0.00108557,0
1973
- 0,0.5537684,1
1974
- 0,0.0012903535,0
1975
- 0,0.0012236463,0
1976
- 0,0.0012290003,0
1977
- 0,0.0008847146,0
1978
- 0,0.17526501,0
1979
- 1,0.6066298,1
1980
- 1,0.6446647,1
1981
- 0,0.00080044393,0
1982
- 0,0.098615125,0
1983
- 1,0.5616812,1
1984
- 0,0.6933321,1
1985
- 0,0.0011550131,0
1986
- 0,0.0011100766,0
1987
- 0,0.5804321,1
1988
- 0,0.0010168053,0
1989
- 0,0.22010425,0
1990
- 1,0.32657132,0
1991
- 0,0.0009580052,0
1992
- 1,0.5661701,1
1993
- 0,0.8999228,1
1994
- 0,0.0010242854,0
1995
- 0,0.00079186546,0
1996
- 0,0.0008344046,0
1997
- 0,0.006115876,0
1998
- 0,0.6555151,1
1999
- 0,0.47155666,0
2000
- 0,0.0011539217,0
2001
- 0,0.42844838,0
2002
- 0,0.00095852744,0
2003
- 0,0.0011793413,0
2004
- 0,0.5727451,1
2005
- 0,0.0011716434,0
2006
- 0,0.017039847,0
2007
- 0,0.0011528861,0
2008
- 0,0.0008173667,0
2009
- 0,0.0068512554,0
2010
- 1,0.752859,1
2011
- 1,0.67884725,1
2012
- 0,0.002062399,0
2013
- 0,0.0008773194,0
2014
- 0,0.4187466,0
2015
- 0,0.0008783964,0
2016
- 0,0.0011804664,0
2017
- 0,0.00090095046,0
2018
- 0,0.0008461568,0
2019
- 0,0.00090098643,0
2020
- 1,0.4443772,0
2021
- 0,0.0012510206,0
2022
- 0,0.0010878003,0
2023
- 0,0.001125162,0
2024
- 0,0.0009025954,0
2025
- 0,0.00082608557,0
2026
- 0,0.0012391479,0
2027
- 0,0.0010274188,0
2028
- 0,0.00085046276,0
2029
- 0,0.00094802276,0
2030
- 1,0.20541348,0
2031
- 1,0.37556043,0
2032
- 0,0.40013596,0
2033
- 0,0.6528374,1
2034
- 0,0.6438984,1
2035
- 0,0.33941498,0
2036
- 1,0.6012913,1
2037
- 1,0.40081245,0
2038
- 0,0.0009239743,0
2039
- 0,0.00088827714,0
2040
- 1,0.69142634,1
2041
- 0,0.00095390016,0
2042
- 0,0.00091235427,0
2043
- 0,0.46778736,0
2044
- 1,0.6588788,1
2045
- 0,0.00085993443,0
2046
- 0,0.00095059595,0
2047
- 0,0.31993032,0
2048
- 1,0.31242043,0
2049
- 1,0.7491844,1
2050
- 0,0.00084422436,0
2051
- 1,0.81296134,1
2052
- 0,0.0009211382,0
2053
- 1,0.45493677,0
2054
- 0,0.0011309453,0
2055
- 1,0.27920976,0
2056
- 0,0.001766247,0
2057
- 1,0.6483464,1
2058
- 0,0.4439611,0
2059
- 0,0.28825438,0
2060
- 0,0.0013631671,0
2061
- 0,0.001132262,0
2062
- 0,0.0013900561,0
2063
- 0,0.0015789722,0
2064
- 0,0.0017141184,0
2065
- 1,0.69817054,1
2066
- 0,0.011650326,0
2067
- 0,0.001025334,0
2068
- 0,0.0011714353,0
2069
- 0,0.0008079835,0
2070
- 1,0.5759009,1
2071
- 0,0.22395323,0
2072
- 1,0.79228,1
2073
- 0,0.0011739928,0
2074
- 0,0.00091123086,0
2075
- 0,0.00080784655,0
2076
- 0,0.0010515592,0
2077
- 0,0.39967638,0
2078
- 0,0.0010145825,0
2079
- 1,0.5423705,1
2080
- 1,0.74459624,1
2081
- 0,0.00082461955,0
2082
- 0,0.0029421886,0
2083
- 0,0.43035188,0
2084
- 0,0.6877102,1
2085
- 0,0.0009364068,0
2086
- 0,0.0012012246,0
2087
- 0,0.000867792,0
2088
- 0,0.15176591,0
2089
- 1,0.77357423,1
2090
- 0,0.0034248335,0
2091
- 0,0.00086901913,0
2092
- 0,0.027418558,0
2093
- 0,0.00096116075,0
2094
- 1,0.6818173,1
2095
- 0,0.0009722299,0
2096
- 1,0.8720049,1
2097
- 0,0.00084275793,0
2098
- 0,0.015429355,0
2099
- 0,0.0009900979,0
2100
- 0,0.0008618162,0
2101
- 0,0.0010659914,0
2102
- 0,0.0011133268,0
2103
- 1,0.6155381,1
2104
- 0,0.0019113526,0
2105
- 0,0.0013318051,0
2106
- 0,0.0021403905,0
2107
- 0,0.0013350251,0
2108
- 0,0.0011662951,0
2109
- 0,0.0009040375,0
2110
- 0,0.0029542346,0
2111
- 0,0.0026244,0
2112
- 0,0.20788012,0
2113
- 0,0.5909254,1
2114
- 0,0.0010929306,0
2115
- 0,0.08327696,0
2116
- 0,0.001310451,0
2117
- 0,0.14910121,0
2118
- 0,0.0011818194,0
2119
- 0,0.001830041,0
2120
- 1,0.69931924,1
2121
- 0,0.0028769197,0
2122
- 0,0.38193005,0
2123
- 0,0.0017126591,0
2124
- 0,0.0008859742,0
2125
- 0,0.41827688,0
2126
- 0,0.0011043509,0
2127
- 0,0.27034986,0
2128
- 0,0.0010202782,0
2129
- 0,0.0010163564,0
2130
- 1,0.85359937,1
2131
- 0,0.001006123,0
2132
- 0,0.00084935484,0
2133
- 0,0.011224475,0
2134
- 0,0.0010811862,0
2135
- 0,0.0010717752,0
2136
- 0,0.3722477,0
2137
- 0,0.0010041799,0
2138
- 0,0.0022786073,0
2139
- 0,0.0010501673,0
2140
- 0,0.44608837,0
2141
- 0,0.00079554186,0
2142
- 0,0.49635226,0
2143
- 0,0.4600765,0
2144
- 1,0.31137487,0
2145
- 0,0.0009268139,0
2146
- 0,0.001196574,0
2147
- 0,0.001026193,0
2148
- 0,0.0009154746,0
2149
- 1,0.33101282,0
2150
- 0,0.0011979094,0
2151
- 0,0.0014616075,0
2152
- 0,0.0014142476,0
2153
- 1,0.8949496,1
2154
- 1,0.8674389,1
2155
- 0,0.0008275794,0
2156
- 0,0.00087104103,0
2157
- 0,0.0010423438,0
2158
- 0,0.31495747,0
2159
- 0,0.0014686183,0
2160
- 1,0.67044705,1
2161
- 0,0.2931136,0
2162
- 0,0.04735578,0
2163
- 0,0.09337471,0
2164
- 0,0.0009415873,0
2165
- 0,0.0010099063,0
2166
- 0,0.53362274,1
2167
- 0,0.00096574856,0
2168
- 0,0.0012844005,0
2169
- 0,0.00083572033,0
2170
- 1,0.8030804,1
2171
- 0,0.0010202461,0
2172
- 0,0.0013574356,0
2173
- 1,0.6834372,1
2174
- 0,0.09290696,0
2175
- 0,0.0009394532,0
2176
- 0,0.0009227664,0
2177
- 1,0.50783664,1
2178
- 1,0.8649702,1
2179
- 1,0.7761252,1
2180
- 0,0.0012978833,0
2181
- [rows 2182-3438 removed; each row is "true_label,predicted_probability,predicted_label", e.g. "0,0.002984829,0"]
 
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:ef9845cf64c142ff16fc915402953a1383e36ecb1c76b6174fae75c0dec59cd4
+ size 54904
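The three added lines are a Git LFS pointer: the CSV's actual contents now live in LFS storage, and the pointer records only the blob's SHA-256 digest and byte size. Below is a minimal sketch of how such a pointer could be checked against a locally downloaded copy of the file; the file paths are assumptions for illustration, not files from this repo.

# Sketch: validate a local data file against a Git LFS pointer file.
# Both paths used in the example call are placeholders.
import hashlib
from pathlib import Path

def parse_lfs_pointer(pointer_path: str) -> dict:
    """Parse the key/value lines of a Git LFS pointer (version, oid, size)."""
    fields = {}
    for line in Path(pointer_path).read_text().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    return fields

def matches_pointer(data_path: str, pointer_path: str) -> bool:
    """Return True if the file's sha256 digest and byte size match the pointer's oid/size."""
    fields = parse_lfs_pointer(pointer_path)
    blob = Path(data_path).read_bytes()
    oid = "sha256:" + hashlib.sha256(blob).hexdigest()
    return oid == fields["oid"] and len(blob) == int(fields["size"])

# Example usage (hypothetical paths):
# matches_pointer("val_predictions_binary.csv", "val_predictions_binary.csv.pointer")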
tokenizer/.ipynb_checkpoints/my_tokenizers-checkpoint.py ADDED
@@ -0,0 +1,398 @@
1
+ import collections
2
+ import logging
3
+ import os
4
+ import re
5
+ import codecs
6
+ import unicodedata
7
+ from typing import List, Optional
8
+ from transformers import PreTrainedTokenizer
9
+ from SmilesPE.tokenizer import SPE_Tokenizer
10
+
11
+ def load_vocab(vocab_file):
12
+ """Loads a vocabulary file into a dictionary."""
13
+ vocab = collections.OrderedDict()
14
+ with open(vocab_file, "r", encoding="utf-8") as reader:
15
+ tokens = reader.readlines()
16
+ for index, token in enumerate(tokens):
17
+ token = token.rstrip("\n")
18
+ vocab[token] = index
19
+ return vocab
20
+
21
+ class Atomwise_Tokenizer(object):
22
+ """Run atom-level SMILES tokenization"""
23
+
24
+ def __init__(self):
25
+ """ Constructs a atom-level Tokenizer.
26
+ """
27
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
28
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
29
+
30
+ self.regex = re.compile(self.regex_pattern)
31
+
32
+ def tokenize(self, text):
33
+ """ Basic Tokenization of a SMILES.
34
+ """
35
+ tokens = [token for token in self.regex.findall(text)]
36
+ return tokens
37
+
38
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
39
+ r"""
40
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
41
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
42
+ should refer to the superclass for more information regarding methods.
43
+ Args:
44
+ vocab_file (:obj:`string`):
45
+ File containing the vocabulary.
46
+ spe_file (:obj:`string`):
47
+ File containing the trained SMILES Pair Encoding vocabulary.
48
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
49
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
50
+ token instead.
51
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
52
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
53
+ for sequence classification or for a text and a question for question answering.
54
+ It is also used as the last token of a sequence built with special tokens.
55
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
56
+ The token used for padding, for example when batching sequences of different lengths.
57
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
58
+ The classifier token which is used when doing sequence classification (classification of the whole
59
+ sequence instead of per-token classification). It is the first token of the sequence when built with
60
+ special tokens.
61
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
62
+ The token used for masking values. This is the token used when training this model with masked language
63
+ modeling. This is the token which the model will try to predict.
64
+ """
65
+
66
+ def __init__(self, vocab_file, spe_file,
67
+ unk_token="[UNK]",
68
+ sep_token="[SEP]",
69
+ pad_token="[PAD]",
70
+ cls_token="[CLS]",
71
+ mask_token="[MASK]",
72
+ **kwargs):
73
+ if not os.path.isfile(vocab_file):
74
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
75
+ if not os.path.isfile(spe_file):
76
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
77
+
78
+ self.vocab = load_vocab(vocab_file)
79
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
80
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
81
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
82
+
83
+ super().__init__(
84
+ unk_token=unk_token,
85
+ sep_token=sep_token,
86
+ pad_token=pad_token,
87
+ cls_token=cls_token,
88
+ mask_token=mask_token,
89
+ **kwargs)
90
+
91
+ @property
92
+ def vocab_size(self):
93
+ return len(self.vocab)
94
+
95
+ def get_vocab(self):
96
+ return dict(self.vocab, **self.added_tokens_encoder)
97
+
98
+ def _tokenize(self, text):
99
+ return self.spe_tokenizer.tokenize(text).split(' ')
100
+
101
+ def _convert_token_to_id(self, token):
102
+ """ Converts a token (str) in an id using the vocab. """
103
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
104
+
105
+ def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
106
+ text = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
107
+ return self.convert_tokens_to_string(text)
108
+
109
+ def _convert_id_to_token(self, index):
110
+ """Converts an index (integer) in a token (str) using the vocab."""
111
+ return self.ids_to_tokens.get(index, self.unk_token)
112
+
113
+ def convert_tokens_to_string(self, tokens):
114
+ """ Converts a sequence of tokens (string) in a single string. """
115
+ out_string = " ".join(tokens).replace(" ##", "").strip()
116
+ return out_string
117
+
118
+ def build_inputs_with_special_tokens(
119
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
120
+ ) -> List[int]:
121
+ """
122
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
123
+ by concatenating and adding special tokens.
124
+ A BERT sequence has the following format:
125
+ - single sequence: ``[CLS] X [SEP]``
126
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
127
+ Args:
128
+ token_ids_0 (:obj:`List[int]`):
129
+ List of IDs to which the special tokens will be added
130
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
131
+ Optional second list of IDs for sequence pairs.
132
+ Returns:
133
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
134
+ """
135
+ if token_ids_1 is None:
136
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
137
+ cls = [self.cls_token_id]
138
+ sep = [self.sep_token_id]
139
+ return cls + token_ids_0 + sep + token_ids_1 + sep
140
+
141
+ def get_special_tokens_mask(
142
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
143
+ ) -> List[int]:
144
+ """
145
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
146
+ special tokens using the tokenizer ``prepare_for_model`` method.
147
+ Args:
148
+ token_ids_0 (:obj:`List[int]`):
149
+ List of ids.
150
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
151
+ Optional second list of IDs for sequence pairs.
152
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
153
+ Set to True if the token list is already formatted with special tokens for the model
154
+ Returns:
155
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
156
+ """
157
+
158
+ if already_has_special_tokens:
159
+ if token_ids_1 is not None:
160
+ raise ValueError(
161
+ "You should not supply a second sequence if the provided sequence of "
162
+ "ids is already formated with special tokens for the model."
163
+ )
164
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
165
+
166
+ if token_ids_1 is not None:
167
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
168
+ return [1] + ([0] * len(token_ids_0)) + [1]
169
+
170
+ def create_token_type_ids_from_sequences(
171
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
172
+ ) -> List[int]:
173
+ """
174
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
175
+ A BERT sequence pair mask has the following format:
176
+ ::
177
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
178
+ | first sequence | second sequence |
179
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
180
+ Args:
181
+ token_ids_0 (:obj:`List[int]`):
182
+ List of ids.
183
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
184
+ Optional second list of IDs for sequence pairs.
185
+ Returns:
186
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
187
+ sequence(s).
188
+ """
189
+ sep = [self.sep_token_id]
190
+ cls = [self.cls_token_id]
191
+ if token_ids_1 is None:
192
+ return len(cls + token_ids_0 + sep) * [0]
193
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
194
+
195
+ def save_vocabulary(self, vocab_path):
196
+ """
197
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
198
+ Args:
199
+ vocab_path (:obj:`str`):
200
+ The directory in which to save the vocabulary.
201
+ Returns:
202
+ :obj:`Tuple(str)`: Paths to the files saved.
203
+ """
204
+ index = 0
205
+ if os.path.isdir(vocab_path):
206
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
207
+ else:
208
+ vocab_file = vocab_path
209
+ with open(vocab_file, "w", encoding="utf-8") as writer:
210
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
211
+ if index != token_index:
212
+ logger.warning(
213
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
214
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
215
+ )
216
+ index = token_index
217
+ writer.write(token + "\n")
218
+ index += 1
219
+ return (vocab_file,)
220
+
221
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
222
+ r"""
223
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
224
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
225
+ should refer to the superclass for more information regarding methods.
226
+ Args:
227
+ vocab_file (:obj:`string`):
228
+ File containing the vocabulary.
229
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
230
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
231
+ token instead.
232
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
233
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
234
+ for sequence classification or for a text and a question for question answering.
235
+ It is also used as the last token of a sequence built with special tokens.
236
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
237
+ The token used for padding, for example when batching sequences of different lengths.
238
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
239
+ The classifier token which is used when doing sequence classification (classification of the whole
240
+ sequence instead of per-token classification). It is the first token of the sequence when built with
241
+ special tokens.
242
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
243
+ The token used for masking values. This is the token used when training this model with masked language
244
+ modeling. This is the token which the model will try to predict.
245
+ """
246
+
247
+ def __init__(
248
+ self,
249
+ vocab_file,
250
+ unk_token="[UNK]",
251
+ sep_token="[SEP]",
252
+ pad_token="[PAD]",
253
+ cls_token="[CLS]",
254
+ mask_token="[MASK]",
255
+ **kwargs
256
+ ):
257
+ super().__init__(
258
+ unk_token=unk_token,
259
+ sep_token=sep_token,
260
+ pad_token=pad_token,
261
+ cls_token=cls_token,
262
+ mask_token=mask_token,
263
+ **kwargs,
264
+ )
265
+
266
+ if not os.path.isfile(vocab_file):
267
+ raise ValueError(
268
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
269
+ )
270
+ self.vocab = load_vocab(vocab_file)
271
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
272
+ self.tokenizer = Atomwise_Tokenizer()
273
+
274
+ @property
275
+ def vocab_size(self):
276
+ return len(self.vocab)
277
+
278
+ def get_vocab(self):
279
+ return dict(self.vocab, **self.added_tokens_encoder)
280
+
281
+ def _tokenize(self, text):
282
+ return self.tokenizer.tokenize(text)
283
+
284
+ def _convert_token_to_id(self, token):
285
+ """ Converts a token (str) in an id using the vocab. """
286
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
287
+
288
+ def _convert_id_to_token(self, index):
289
+ """Converts an index (integer) in a token (str) using the vocab."""
290
+ return self.ids_to_tokens.get(index, self.unk_token)
291
+
292
+ def convert_tokens_to_string(self, tokens):
293
+ """ Converts a sequence of tokens (string) in a single string. """
294
+ out_string = " ".join(tokens).replace(" ##", "").strip()
295
+ return out_string
296
+
297
+ def build_inputs_with_special_tokens(
298
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
299
+ ) -> List[int]:
300
+ """
301
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
302
+ by concatenating and adding special tokens.
303
+ A BERT sequence has the following format:
304
+ - single sequence: ``[CLS] X [SEP]``
305
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
306
+ Args:
307
+ token_ids_0 (:obj:`List[int]`):
308
+ List of IDs to which the special tokens will be added
309
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
310
+ Optional second list of IDs for sequence pairs.
311
+ Returns:
312
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
313
+ """
314
+ if token_ids_1 is None:
315
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
316
+ cls = [self.cls_token_id]
317
+ sep = [self.sep_token_id]
318
+ return cls + token_ids_0 + sep + token_ids_1 + sep
319
+
320
+ def get_special_tokens_mask(
321
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
322
+ ) -> List[int]:
323
+ """
324
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
325
+ special tokens using the tokenizer ``prepare_for_model`` method.
326
+ Args:
327
+ token_ids_0 (:obj:`List[int]`):
328
+ List of ids.
329
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
330
+ Optional second list of IDs for sequence pairs.
331
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
332
+ Set to True if the token list is already formatted with special tokens for the model
333
+ Returns:
334
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
335
+ """
336
+
337
+ if already_has_special_tokens:
338
+ if token_ids_1 is not None:
339
+ raise ValueError(
340
+ "You should not supply a second sequence if the provided sequence of "
341
+ "ids is already formated with special tokens for the model."
342
+ )
343
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
344
+
345
+ if token_ids_1 is not None:
346
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
347
+ return [1] + ([0] * len(token_ids_0)) + [1]
348
+
349
+ def create_token_type_ids_from_sequences(
350
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
351
+ ) -> List[int]:
352
+ """
353
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
354
+ A BERT sequence pair mask has the following format:
355
+ ::
356
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
357
+ | first sequence | second sequence |
358
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
359
+ Args:
360
+ token_ids_0 (:obj:`List[int]`):
361
+ List of ids.
362
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
363
+ Optional second list of IDs for sequence pairs.
364
+ Returns:
365
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
366
+ sequence(s).
367
+ """
368
+ sep = [self.sep_token_id]
369
+ cls = [self.cls_token_id]
370
+ if token_ids_1 is None:
371
+ return len(cls + token_ids_0 + sep) * [0]
372
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
373
+
374
+ def save_vocabulary(self, vocab_path):
375
+ """
376
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
377
+ Args:
378
+ vocab_path (:obj:`str`):
379
+ The directory in which to save the vocabulary.
380
+ Returns:
381
+ :obj:`Tuple(str)`: Paths to the files saved.
382
+ """
383
+ index = 0
384
+ if os.path.isdir(vocab_path):
385
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
386
+ else:
387
+ vocab_file = vocab_path
388
+ with open(vocab_file, "w", encoding="utf-8") as writer:
389
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
390
+ if index != token_index:
391
+ logger.warning(
392
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
393
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
394
+ )
395
+ index = token_index
396
+ writer.write(token + "\n")
397
+ index += 1
398
+ return (vocab_file,)
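For reference, a minimal usage sketch of the SMILES_SPE_Tokenizer defined above. The vocabulary/SPE file names and the example SMILES are assumptions for illustration (not values taken from this commit), and behavior may vary slightly across transformers versions.

# Usage sketch (assumed file locations; point these at the actual vocab and SPE merge files).
from my_tokenizers import SMILES_SPE_Tokenizer

tokenizer = SMILES_SPE_Tokenizer(
    vocab_file="vocab.txt",        # assumed path: one token per line
    spe_file="spe_splits.txt",     # assumed path: trained SMILES Pair Encoding merges
)

smiles = "CC(=O)Oc1ccccc1C(=O)O"   # arbitrary example molecule (aspirin)
encoding = tokenizer(smiles)        # standard PreTrainedTokenizer call
print(tokenizer.convert_ids_to_tokens(encoding["input_ids"]))  # [CLS] ... [SEP] around SPE tokens
print(tokenizer.decode(encoding["input_ids"]))                  # space-joined token string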
tokenizer/__pycache__/my_tokenizers.cpython-310.pyc ADDED
Binary file (15.5 kB).
 
tokenizer/my_tokenizers.py ADDED
@@ -0,0 +1,398 @@
1
+ import collections
2
+ import logging
3
+ import os
4
+ import re
5
+ import codecs
6
+ import unicodedata
7
+ from typing import List, Optional
8
+ from transformers import PreTrainedTokenizer
9
+ from SmilesPE.tokenizer import SPE_Tokenizer
10
+
11
+ def load_vocab(vocab_file):
12
+ """Loads a vocabulary file into a dictionary."""
13
+ vocab = collections.OrderedDict()
14
+ with open(vocab_file, "r", encoding="utf-8") as reader:
15
+ tokens = reader.readlines()
16
+ for index, token in enumerate(tokens):
17
+ token = token.rstrip("\n")
18
+ vocab[token] = index
19
+ return vocab
20
+
21
+ class Atomwise_Tokenizer(object):
22
+ """Run atom-level SMILES tokenization"""
23
+
24
+ def __init__(self):
25
+ """ Constructs a atom-level Tokenizer.
26
+ """
27
+ # self.regex_pattern = r"(\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
28
+ self.regex_pattern = r"(\([^\(\)]{0,4}\)|\[[^\]]+]|Br?|Cl?|N|O|S|P|F|I|b|c|n|o|s|p|\(|\)|\.|=|#|-|\+|\\|\/\/?|:|~|@|\?|>>?|\*|\$|\%[0-9]{2}|[0-9])"
29
+
30
+ self.regex = re.compile(self.regex_pattern)
31
+
32
+ def tokenize(self, text):
33
+ """ Basic Tokenization of a SMILES.
34
+ """
35
+ tokens = [token for token in self.regex.findall(text)]
36
+ return tokens
37
+
38
+ class SMILES_SPE_Tokenizer(PreTrainedTokenizer):
39
+ r"""
40
+ Constructs a SMILES tokenizer. Based on SMILES Pair Encoding (https://github.com/XinhaoLi74/SmilesPE).
41
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
42
+ should refer to the superclass for more information regarding methods.
43
+ Args:
44
+ vocab_file (:obj:`string`):
45
+ File containing the vocabulary.
46
+ spe_file (:obj:`string`):
47
+ File containing the trained SMILES Pair Encoding vocabulary.
48
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
49
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
50
+ token instead.
51
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
52
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
53
+ for sequence classification or for a text and a question for question answering.
54
+ It is also used as the last token of a sequence built with special tokens.
55
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
56
+ The token used for padding, for example when batching sequences of different lengths.
57
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
58
+ The classifier token which is used when doing sequence classification (classification of the whole
59
+ sequence instead of per-token classification). It is the first token of the sequence when built with
60
+ special tokens.
61
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
62
+ The token used for masking values. This is the token used when training this model with masked language
63
+ modeling. This is the token which the model will try to predict.
64
+ """
65
+
66
+ def __init__(self, vocab_file, spe_file,
67
+ unk_token="[UNK]",
68
+ sep_token="[SEP]",
69
+ pad_token="[PAD]",
70
+ cls_token="[CLS]",
71
+ mask_token="[MASK]",
72
+ **kwargs):
73
+ if not os.path.isfile(vocab_file):
74
+ raise ValueError("Can't find a vocabulary file at path '{}'.".format(vocab_file))
75
+ if not os.path.isfile(spe_file):
76
+ raise ValueError("Can't find a SPE vocabulary file at path '{}'.".format(spe_file))
77
+
78
+ self.vocab = load_vocab(vocab_file)
79
+ self.spe_vocab = open(spe_file, 'r', encoding='utf-8')
80
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
81
+ self.spe_tokenizer = SPE_Tokenizer(self.spe_vocab)
82
+
83
+ super().__init__(
84
+ unk_token=unk_token,
85
+ sep_token=sep_token,
86
+ pad_token=pad_token,
87
+ cls_token=cls_token,
88
+ mask_token=mask_token,
89
+ **kwargs)
90
+
91
+ @property
92
+ def vocab_size(self):
93
+ return len(self.vocab)
94
+
95
+ def get_vocab(self):
96
+ return dict(self.vocab, **self.added_tokens_encoder)
97
+
98
+ def _tokenize(self, text):
99
+ return self.spe_tokenizer.tokenize(text).split(' ')
100
+
101
+ def _convert_token_to_id(self, token):
102
+ """ Converts a token (str) in an id using the vocab. """
103
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
104
+
105
+ def decode(self, token_ids, skip_special_tokens=False, clean_up_tokenization_spaces=True):
106
+ text = self.convert_ids_to_tokens(token_ids, skip_special_tokens=skip_special_tokens)
107
+ return self.convert_tokens_to_string(text)
108
+
109
+ def _convert_id_to_token(self, index):
110
+ """Converts an index (integer) in a token (str) using the vocab."""
111
+ return self.ids_to_tokens.get(index, self.unk_token)
112
+
113
+ def convert_tokens_to_string(self, tokens):
114
+ """ Converts a sequence of tokens (string) in a single string. """
115
+ out_string = " ".join(tokens).replace(" ##", "").strip()
116
+ return out_string
117
+
118
+ def build_inputs_with_special_tokens(
119
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
120
+ ) -> List[int]:
121
+ """
122
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
123
+ by concatenating and adding special tokens.
124
+ A BERT sequence has the following format:
125
+ - single sequence: ``[CLS] X [SEP]``
126
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
127
+ Args:
128
+ token_ids_0 (:obj:`List[int]`):
129
+ List of IDs to which the special tokens will be added
130
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
131
+ Optional second list of IDs for sequence pairs.
132
+ Returns:
133
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
134
+ """
135
+ if token_ids_1 is None:
136
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
137
+ cls = [self.cls_token_id]
138
+ sep = [self.sep_token_id]
139
+ return cls + token_ids_0 + sep + token_ids_1 + sep
140
+
141
+ def get_special_tokens_mask(
142
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
143
+ ) -> List[int]:
144
+ """
145
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
146
+ special tokens using the tokenizer ``prepare_for_model`` method.
147
+ Args:
148
+ token_ids_0 (:obj:`List[int]`):
149
+ List of ids.
150
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
151
+ Optional second list of IDs for sequence pairs.
152
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
153
+ Set to True if the token list is already formatted with special tokens for the model
154
+ Returns:
155
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
156
+ """
157
+
158
+ if already_has_special_tokens:
159
+ if token_ids_1 is not None:
160
+ raise ValueError(
161
+ "You should not supply a second sequence if the provided sequence of "
162
+ "ids is already formatted with special tokens for the model."
163
+ )
164
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
165
+
166
+ if token_ids_1 is not None:
167
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
168
+ return [1] + ([0] * len(token_ids_0)) + [1]
169
+
170
+ def create_token_type_ids_from_sequences(
171
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
172
+ ) -> List[int]:
173
+ """
174
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
175
+ A BERT sequence pair mask has the following format:
176
+ ::
177
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
178
+ | first sequence | second sequence |
179
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
180
+ Args:
181
+ token_ids_0 (:obj:`List[int]`):
182
+ List of ids.
183
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
184
+ Optional second list of IDs for sequence pairs.
185
+ Returns:
186
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
187
+ sequence(s).
188
+ """
189
+ sep = [self.sep_token_id]
190
+ cls = [self.cls_token_id]
191
+ if token_ids_1 is None:
192
+ return len(cls + token_ids_0 + sep) * [0]
193
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
194
+
195
+ def save_vocabulary(self, vocab_path):
196
+ """
197
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
198
+ Args:
199
+ vocab_path (:obj:`str`):
200
+ The directory in which to save the vocabulary.
201
+ Returns:
202
+ :obj:`Tuple(str)`: Paths to the files saved.
203
+ """
204
+ index = 0
205
+ if os.path.isdir(vocab_path):
206
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
207
+ else:
208
+ vocab_file = vocab_path
209
+ with open(vocab_file, "w", encoding="utf-8") as writer:
210
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
211
+ if index != token_index:
212
+ logger.warning(
213
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
214
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
215
+ )
216
+ index = token_index
217
+ writer.write(token + "\n")
218
+ index += 1
219
+ return (vocab_file,)
220
+
221
+ class SMILES_Atomwise_Tokenizer(PreTrainedTokenizer):
222
+ r"""
223
+ Constructs an atom-wise SMILES tokenizer, using the Atomwise_Tokenizer from the SmilesPE package (https://github.com/XinhaoLi74/SmilesPE).
224
+ This tokenizer inherits from :class:`~transformers.PreTrainedTokenizer` which contains most of the methods. Users
225
+ should refer to the superclass for more information regarding methods.
226
+ Args:
227
+ vocab_file (:obj:`string`):
228
+ File containing the vocabulary.
229
+ unk_token (:obj:`string`, `optional`, defaults to "[UNK]"):
230
+ The unknown token. A token that is not in the vocabulary cannot be converted to an ID and is set to be this
231
+ token instead.
232
+ sep_token (:obj:`string`, `optional`, defaults to "[SEP]"):
233
+ The separator token, which is used when building a sequence from multiple sequences, e.g. two sequences
234
+ for sequence classification or for a text and a question for question answering.
235
+ It is also used as the last token of a sequence built with special tokens.
236
+ pad_token (:obj:`string`, `optional`, defaults to "[PAD]"):
237
+ The token used for padding, for example when batching sequences of different lengths.
238
+ cls_token (:obj:`string`, `optional`, defaults to "[CLS]"):
239
+ The classifier token which is used when doing sequence classification (classification of the whole
240
+ sequence instead of per-token classification). It is the first token of the sequence when built with
241
+ special tokens.
242
+ mask_token (:obj:`string`, `optional`, defaults to "[MASK]"):
243
+ The token used for masking values. This is the token used when training this model with masked language
244
+ modeling. This is the token which the model will try to predict.
245
+ """
246
+
247
+ def __init__(
248
+ self,
249
+ vocab_file,
250
+ unk_token="[UNK]",
251
+ sep_token="[SEP]",
252
+ pad_token="[PAD]",
253
+ cls_token="[CLS]",
254
+ mask_token="[MASK]",
255
+ **kwargs
256
+ ):
257
+ super().__init__(
258
+ unk_token=unk_token,
259
+ sep_token=sep_token,
260
+ pad_token=pad_token,
261
+ cls_token=cls_token,
262
+ mask_token=mask_token,
263
+ **kwargs,
264
+ )
265
+
266
+ if not os.path.isfile(vocab_file):
267
+ raise ValueError(
268
+ "Can't find a vocabulary file at path '{}'.".format(vocab_file)
269
+ )
270
+ self.vocab = load_vocab(vocab_file)
271
+ self.ids_to_tokens = collections.OrderedDict([(ids, tok) for tok, ids in self.vocab.items()])
272
+ self.tokenizer = Atomwise_Tokenizer()
273
+
274
+ @property
275
+ def vocab_size(self):
276
+ return len(self.vocab)
277
+
278
+ def get_vocab(self):
279
+ return dict(self.vocab, **self.added_tokens_encoder)
280
+
281
+ def _tokenize(self, text):
282
+ return self.tokenizer.tokenize(text)
283
+
284
+ def _convert_token_to_id(self, token):
285
+ """ Converts a token (str) to an id using the vocab. """
286
+ return self.vocab.get(token, self.vocab.get(self.unk_token))
287
+
288
+ def _convert_id_to_token(self, index):
289
+ """Converts an index (integer) to a token (str) using the vocab."""
290
+ return self.ids_to_tokens.get(index, self.unk_token)
291
+
292
+ def convert_tokens_to_string(self, tokens):
293
+ """ Converts a sequence of tokens (strings) into a single string. """
294
+ out_string = " ".join(tokens).replace(" ##", "").strip()
295
+ return out_string
296
+
297
+ def build_inputs_with_special_tokens(
298
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
299
+ ) -> List[int]:
300
+ """
301
+ Build model inputs from a sequence or a pair of sequence for sequence classification tasks
302
+ by concatenating and adding special tokens.
303
+ A BERT sequence has the following format:
304
+ - single sequence: ``[CLS] X [SEP]``
305
+ - pair of sequences: ``[CLS] A [SEP] B [SEP]``
306
+ Args:
307
+ token_ids_0 (:obj:`List[int]`):
308
+ List of IDs to which the special tokens will be added
309
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
310
+ Optional second list of IDs for sequence pairs.
311
+ Returns:
312
+ :obj:`List[int]`: list of `input IDs <../glossary.html#input-ids>`__ with the appropriate special tokens.
313
+ """
314
+ if token_ids_1 is None:
315
+ return [self.cls_token_id] + token_ids_0 + [self.sep_token_id]
316
+ cls = [self.cls_token_id]
317
+ sep = [self.sep_token_id]
318
+ return cls + token_ids_0 + sep + token_ids_1 + sep
319
+
320
+ def get_special_tokens_mask(
321
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None, already_has_special_tokens: bool = False
322
+ ) -> List[int]:
323
+ """
324
+ Retrieves sequence ids from a token list that has no special tokens added. This method is called when adding
325
+ special tokens using the tokenizer ``prepare_for_model`` method.
326
+ Args:
327
+ token_ids_0 (:obj:`List[int]`):
328
+ List of ids.
329
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
330
+ Optional second list of IDs for sequence pairs.
331
+ already_has_special_tokens (:obj:`bool`, `optional`, defaults to :obj:`False`):
332
+ Set to True if the token list is already formatted with special tokens for the model
333
+ Returns:
334
+ :obj:`List[int]`: A list of integers in the range [0, 1]: 1 for a special token, 0 for a sequence token.
335
+ """
336
+
337
+ if already_has_special_tokens:
338
+ if token_ids_1 is not None:
339
+ raise ValueError(
340
+ "You should not supply a second sequence if the provided sequence of "
341
+ "ids is already formatted with special tokens for the model."
342
+ )
343
+ return list(map(lambda x: 1 if x in [self.sep_token_id, self.cls_token_id] else 0, token_ids_0))
344
+
345
+ if token_ids_1 is not None:
346
+ return [1] + ([0] * len(token_ids_0)) + [1] + ([0] * len(token_ids_1)) + [1]
347
+ return [1] + ([0] * len(token_ids_0)) + [1]
348
+
349
+ def create_token_type_ids_from_sequences(
350
+ self, token_ids_0: List[int], token_ids_1: Optional[List[int]] = None
351
+ ) -> List[int]:
352
+ """
353
+ Creates a mask from the two sequences passed to be used in a sequence-pair classification task.
354
+ A BERT sequence pair mask has the following format:
355
+ ::
356
+ 0 0 0 0 0 0 0 0 0 0 0 1 1 1 1 1 1 1 1 1
357
+ | first sequence | second sequence |
358
+ if token_ids_1 is None, only returns the first portion of the mask (0's).
359
+ Args:
360
+ token_ids_0 (:obj:`List[int]`):
361
+ List of ids.
362
+ token_ids_1 (:obj:`List[int]`, `optional`, defaults to :obj:`None`):
363
+ Optional second list of IDs for sequence pairs.
364
+ Returns:
365
+ :obj:`List[int]`: List of `token type IDs <../glossary.html#token-type-ids>`_ according to the given
366
+ sequence(s).
367
+ """
368
+ sep = [self.sep_token_id]
369
+ cls = [self.cls_token_id]
370
+ if token_ids_1 is None:
371
+ return len(cls + token_ids_0 + sep) * [0]
372
+ return len(cls + token_ids_0 + sep) * [0] + len(token_ids_1 + sep) * [1]
373
+
374
+ def save_vocabulary(self, vocab_path):
375
+ """
376
+ Save the sentencepiece vocabulary (copy original file) and special tokens file to a directory.
377
+ Args:
378
+ vocab_path (:obj:`str`):
379
+ The directory in which to save the vocabulary.
380
+ Returns:
381
+ :obj:`Tuple(str)`: Paths to the files saved.
382
+ """
383
+ index = 0
384
+ if os.path.isdir(vocab_path):
385
+ vocab_file = os.path.join(vocab_path, VOCAB_FILES_NAMES["vocab_file"])
386
+ else:
387
+ vocab_file = vocab_path
388
+ with open(vocab_file, "w", encoding="utf-8") as writer:
389
+ for token, token_index in sorted(self.vocab.items(), key=lambda kv: kv[1]):
390
+ if index != token_index:
391
+ logger.warning(
392
+ "Saving vocabulary to {}: vocabulary indices are not consecutive."
393
+ " Please check that the vocabulary is not corrupted!".format(vocab_file)
394
+ )
395
+ index = token_index
396
+ writer.write(token + "\n")
397
+ index += 1
398
+ return (vocab_file,)
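
Not part of the commit, for orientation only: a minimal usage sketch of the SPE tokenizer defined above, assuming the repository root is on PYTHONPATH, that torch/transformers/SmilesPE are installed, and using the new_vocab.txt / new_splits.txt files added below. The example SMILES string and max_length are illustrative.

# Minimal sketch (assumptions as stated above).
from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer

tok = SMILES_SPE_Tokenizer(
    vocab_file="tokenizer/new_vocab.txt",
    spe_file="tokenizer/new_splits.txt",
)

smiles = "CC(=O)NCCc1ccccc1"  # illustrative input, not from the dataset
enc = tok(smiles, truncation=True, max_length=128, return_tensors="pt")

# ids -> tokens round trip; [CLS]/[SEP] are added by build_inputs_with_special_tokens
print(tok.convert_ids_to_tokens(enc["input_ids"][0].tolist()))
print(tok.decode(enc["input_ids"][0].tolist()))
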
tokenizer/new_splits.txt ADDED
@@ -0,0 +1,159 @@
1
+ c 1
2
+ c 2
3
+ c 3
4
+ c 4
5
+ c 5
6
+ c 6
7
+ c 7
8
+ c 8
9
+ c 9
10
+ ( c1
11
+ ( c2
12
+ c1 )
13
+ c2 )
14
+ n 1
15
+ n 2
16
+ n 3
17
+ n 4
18
+ n 5
19
+ n 6
20
+ n 7
21
+ n 8
22
+ n 9
23
+ ( n1
24
+ ( n2
25
+ n1 )
26
+ n2 )
27
+ O 1
28
+ O 2
29
+ O 3
30
+ O 4
31
+ O 5
32
+ O 6
33
+ O 7
34
+ O 8
35
+ O 9
36
+ ( O1
37
+ ( O2
38
+ O2 )
39
+ O2 )
40
+ = O
41
+ = C
42
+ = c
43
+ = N
44
+ = n
45
+ =C C
46
+ =C N
47
+ =C c
48
+ =c c
49
+ =N C
50
+ =N c
51
+ =n C
52
+ =n c
53
+ # N
54
+ # C
55
+ #N C
56
+ #C C
57
+ #C N
58
+ #N N
59
+ ( C
60
+ C )
61
+ ( O
62
+ O )
63
+ ( N
64
+ N )
65
+ Br c
66
+ ( =O
67
+ (=O )
68
+ C (=O)
69
+ C =O
70
+ C =N
71
+ C #N
72
+ C #C
73
+ C C
74
+ CC C
75
+ CC N
76
+ CC O
77
+ CC S
78
+ CC c
79
+ CC n
80
+ C N
81
+ CN C
82
+ CN c
83
+ C O
84
+ CO C
85
+ CO N
86
+ CO c
87
+ C S
88
+ CS C
89
+ CS S
90
+ CS c
91
+ C c
92
+ Cl c
93
+ C n
94
+ F c
95
+ N C
96
+ NC C
97
+ NC c
98
+ N N
99
+ N O
100
+ N c
101
+ N n
102
+ O C
103
+ OC C
104
+ OC O
105
+ OC c
106
+ O N
107
+ O O
108
+ O c
109
+ S C
110
+ SC C
111
+ SC c
112
+ S S
113
+ S c
114
+ c c
115
+ cc c
116
+ cc n
117
+ cc o
118
+ cc s
119
+ cc cc
120
+ c n
121
+ cn c
122
+ cn n
123
+ c o
124
+ co c
125
+ c s
126
+ cs c
127
+ cs n
128
+ n c
129
+ nc c
130
+ nc n
131
+ nc o
132
+ nc s
133
+ n n
134
+ nn c
135
+ nn n
136
+ n o
137
+ no c
138
+ no n
139
+ n s
140
+ ns c
141
+ ns n
142
+ o c
143
+ oc c
144
+ o n
145
+ s c
146
+ sc c
147
+ sc n
148
+ s n
149
+ N P
150
+ P N
151
+ C P
152
+ P C
153
+ N S
154
+ S N
155
+ C S
156
+ S C
157
+ S P
158
+ P S
159
+ C I
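
Note (not part of the commit): each line of new_splits.txt lists a pair of tokens to be merged, in priority order, in the style of a BPE merge table; this is the file that SPE_Tokenizer consumes via the open file handle in my_tokenizers.py above. A quick sketch of using it directly, assuming the SmilesPE package that my_tokenizers.py wraps:

# Sketch only; SPE_Tokenizer here is the SmilesPE class used by SMILES_SPE_Tokenizer above.
import codecs
from SmilesPE.tokenizer import SPE_Tokenizer

spe_vocab = codecs.open("tokenizer/new_splits.txt", "r", encoding="utf-8")
spe = SPE_Tokenizer(spe_vocab)

# tokenize() returns a space-separated string of SPE tokens (see _tokenize above)
print(spe.tokenize("CC(=O)NCCc1ccccc1").split(" "))
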
tokenizer/new_vocab.txt ADDED
@@ -0,0 +1,586 @@
1
+ [PAD]
2
+ [UNK]
3
+ [CLS]
4
+ [SEP]
5
+ [MASK]
6
+ #
7
+ %
8
+ (
9
+ )
10
+ +
11
+ -
12
+ /
13
+ 0
14
+ 1
15
+ 2
16
+ 3
17
+ 4
18
+ 5
19
+ 6
20
+ 7
21
+ 8
22
+ 9
23
+ =
24
+ @
25
+ A
26
+ B
27
+ Br
28
+ Brc
29
+ C
30
+ CC
31
+ CCC
32
+ CCN
33
+ CCO
34
+ CCS
35
+ CCc
36
+ CCn
37
+ CN
38
+ CNC
39
+ CNc
40
+ CO
41
+ COC
42
+ CON
43
+ COc
44
+ CS
45
+ CSC
46
+ CSS
47
+ CSc
48
+ Cc
49
+ Cl
50
+ Clc
51
+ Cn
52
+ F
53
+ Fc
54
+ H
55
+ I
56
+ K
57
+ L
58
+ M
59
+ N
60
+ NC
61
+ NCC
62
+ NCc
63
+ NN
64
+ NO
65
+ Nc
66
+ Nn
67
+ O
68
+ OC
69
+ OCC
70
+ OCO
71
+ OCc
72
+ ON
73
+ OO
74
+ Oc
75
+ P
76
+ R
77
+ S
78
+ SC
79
+ SCC
80
+ SCc
81
+ SS
82
+ Sc
83
+ T
84
+ X
85
+ Z
86
+ [
87
+ \\
88
+ (/
89
+ ]
90
+ a
91
+ b
92
+ c
93
+ cc
94
+ ccc
95
+ ccn
96
+ cco
97
+ ccs
98
+ cn
99
+ cnc
100
+ cnn
101
+ co
102
+ coc
103
+ cs
104
+ csc
105
+ csn
106
+ e
107
+ g
108
+ i
109
+ l
110
+ n
111
+ nc
112
+ ncc
113
+ ncn
114
+ nco
115
+ ncs
116
+ nn
117
+ nnc
118
+ nnn
119
+ no
120
+ noc
121
+ non
122
+ ns
123
+ nsc
124
+ nsn
125
+ o
126
+ oc
127
+ occ
128
+ on
129
+ p
130
+ r
131
+ s
132
+ sc
133
+ scc
134
+ scn
135
+ sn
136
+ t
137
+ c1
138
+ c2
139
+ c3
140
+ c4
141
+ c5
142
+ c6
143
+ c7
144
+ c8
145
+ c9
146
+ n1
147
+ n2
148
+ n3
149
+ n4
150
+ n5
151
+ n6
152
+ n7
153
+ n8
154
+ n9
155
+ O1
156
+ O2
157
+ O3
158
+ O4
159
+ O5
160
+ O6
161
+ O7
162
+ O8
163
+ O9
164
+ (c1
165
+ (c2
166
+ c1)
167
+ c2)
168
+ (n1
169
+ (n2
170
+ n1)
171
+ n2)
172
+ (O1
173
+ (O2
174
+ O2)
175
+ =O
176
+ =C
177
+ =c
178
+ =N
179
+ =n
180
+ =CC
181
+ =CN
182
+ =Cc
183
+ =cc
184
+ =NC
185
+ =Nc
186
+ =nC
187
+ =nc
188
+ #C
189
+ #CC
190
+ #CN
191
+ #N
192
+ #NC
193
+ #NN
194
+ (C
195
+ C)
196
+ (O
197
+ O)
198
+ (N
199
+ N)
200
+ NP
201
+ PN
202
+ CP
203
+ PC
204
+ NS
205
+ SN
206
+ SP
207
+ PS
208
+ C(=O)
209
+ (/Br)
210
+ (/C#N)
211
+ (/C)
212
+ (/C=N)
213
+ (/C=O)
214
+ (/CBr)
215
+ (/CC)
216
+ (/CCC)
217
+ (/CCF)
218
+ (/CCN)
219
+ (/CCO)
220
+ (/CCl)
221
+ (/CI)
222
+ (/CN)
223
+ (/CO)
224
+ (/CS)
225
+ (/Cl)
226
+ (/F)
227
+ (/I)
228
+ (/N)
229
+ (/NC)
230
+ (/NCC)
231
+ (/NO)
232
+ (/O)
233
+ (/OC)
234
+ (/OCC)
235
+ (/S)
236
+ (/SC)
237
+ (=C)
238
+ (=C/C)
239
+ (=C/F)
240
+ (=C/I)
241
+ (=C/N)
242
+ (=C/O)
243
+ (=CBr)
244
+ (=CC)
245
+ (=CCF)
246
+ (=CCN)
247
+ (=CCO)
248
+ (=CCl)
249
+ (=CF)
250
+ (=CI)
251
+ (=CN)
252
+ (=CO)
253
+ (=C\\C)
254
+ (=C\\F)
255
+ (=C\\I)
256
+ (=C\\N)
257
+ (=C\\O)
258
+ (=N)
259
+ (=N/C)
260
+ (=N/N)
261
+ (=N/O)
262
+ (=NBr)
263
+ (=NC)
264
+ (=NCC)
265
+ (=NCl)
266
+ (=NN)
267
+ (=NO)
268
+ (=NOC)
269
+ (=N\\C)
270
+ (=N\\N)
271
+ (=N\\O)
272
+ (=O)
273
+ (=S)
274
+ (B)
275
+ (Br)
276
+ (C#C)
277
+ (C#CC)
278
+ (C#CI)
279
+ (C#CO)
280
+ (C#N)
281
+ (C#SN)
282
+ (C)
283
+ (C=C)
284
+ (C=CF)
285
+ (C=CI)
286
+ (C=N)
287
+ (C=NN)
288
+ (C=NO)
289
+ (C=O)
290
+ (C=S)
291
+ (CBr)
292
+ (CC#C)
293
+ (CC#N)
294
+ (CC)
295
+ (CC=C)
296
+ (CC=O)
297
+ (CCBr)
298
+ (CCC)
299
+ (CCCC)
300
+ (CCCF)
301
+ (CCCI)
302
+ (CCCN)
303
+ (CCCO)
304
+ (CCCS)
305
+ (CCCl)
306
+ (CCF)
307
+ (CCI)
308
+ (CCN)
309
+ (CCNC)
310
+ (CCNN)
311
+ (CCNO)
312
+ (CCO)
313
+ (CCOC)
314
+ (CCON)
315
+ (CCS)
316
+ (CCSC)
317
+ (CCl)
318
+ (CF)
319
+ (CI)
320
+ (CN)
321
+ (CN=O)
322
+ (CNC)
323
+ (CNCC)
324
+ (CNCO)
325
+ (CNN)
326
+ (CNNC)
327
+ (CNO)
328
+ (CNOC)
329
+ (CO)
330
+ (COC)
331
+ (COCC)
332
+ (COCI)
333
+ (COCN)
334
+ (COCO)
335
+ (COF)
336
+ (CON)
337
+ (COO)
338
+ (CS)
339
+ (CSC)
340
+ (CSCC)
341
+ (CSCF)
342
+ (CSO)
343
+ (Cl)
344
+ (F)
345
+ (I)
346
+ (N)
347
+ (N=N)
348
+ (N=NO)
349
+ (N=O)
350
+ (N=S)
351
+ (NBr)
352
+ (NC#N)
353
+ (NC)
354
+ (NC=N)
355
+ (NC=O)
356
+ (NC=S)
357
+ (NCBr)
358
+ (NCC)
359
+ (NCCC)
360
+ (NCCF)
361
+ (NCCN)
362
+ (NCCO)
363
+ (NCCS)
364
+ (NCCl)
365
+ (NCNC)
366
+ (NCO)
367
+ (NCS)
368
+ (NCl)
369
+ (NN)
370
+ (NN=O)
371
+ (NNC)
372
+ (NO)
373
+ (NOC)
374
+ (O)
375
+ (OC#N)
376
+ (OC)
377
+ (OC=C)
378
+ (OC=O)
379
+ (OC=S)
380
+ (OCBr)
381
+ (OCC)
382
+ (OCCC)
383
+ (OCCF)
384
+ (OCCI)
385
+ (OCCN)
386
+ (OCCO)
387
+ (OCCS)
388
+ (OCCl)
389
+ (OCF)
390
+ (OCI)
391
+ (OCO)
392
+ (OCOC)
393
+ (OCON)
394
+ (OCSC)
395
+ (OCl)
396
+ (OI)
397
+ (ON)
398
+ (OO)
399
+ (OOC)
400
+ (OOCC)
401
+ (OOSN)
402
+ (OSC)
403
+ (P)
404
+ (S)
405
+ (SC#N)
406
+ (SC)
407
+ (SCC)
408
+ (SCCC)
409
+ (SCCF)
410
+ (SCCN)
411
+ (SCCO)
412
+ (SCCS)
413
+ (SCCl)
414
+ (SCF)
415
+ (SCN)
416
+ (SCOC)
417
+ (SCSC)
418
+ (SCl)
419
+ (SI)
420
+ (SN)
421
+ (SN=O)
422
+ (SO)
423
+ (SOC)
424
+ (SOOO)
425
+ (SS)
426
+ (SSC)
427
+ (SSCC)
428
+ ([At])
429
+ ([O-])
430
+ ([O])
431
+ ([S-])
432
+ (\\Br)
433
+ (\\C#N)
434
+ (\\C)
435
+ (\\C=N)
436
+ (\\C=O)
437
+ (\\CBr)
438
+ (\\CC)
439
+ (\\CCC)
440
+ (\\CCO)
441
+ (\\CCl)
442
+ (\\CF)
443
+ (\\CN)
444
+ (\\CNC)
445
+ (\\CO)
446
+ (\\COC)
447
+ (\\Cl)
448
+ (\\F)
449
+ (\\I)
450
+ (\\N)
451
+ (\\NC)
452
+ (\\NCC)
453
+ (\\NN)
454
+ (\\NO)
455
+ (\\NOC)
456
+ (\\O)
457
+ (\\OC)
458
+ (\\OCC)
459
+ (\\ON)
460
+ (\\S)
461
+ (\\SC)
462
+ (\\SCC)
463
+ [Ag+]
464
+ [Ag-4]
465
+ [Ag]
466
+ [Al-3]
467
+ [Al]
468
+ [As+]
469
+ [AsH3]
470
+ [AsH]
471
+ [As]
472
+ [At]
473
+ [B-]
474
+ [B@-]
475
+ [B@@-]
476
+ [BH-]
477
+ [BH2-]
478
+ [BH3-]
479
+ [B]
480
+ [Ba]
481
+ [Br+2]
482
+ [BrH]
483
+ [Br]
484
+ [C+]
485
+ [C-]
486
+ [C@@H]
487
+ [C@@]
488
+ [C@H]
489
+ [C@]
490
+ [CH-]
491
+ [CH2]
492
+ [CH3]
493
+ [CH]
494
+ [C]
495
+ [CaH2]
496
+ [Ca]
497
+ [Cl+2]
498
+ [Cl+3]
499
+ [Cl+]
500
+ [Cs]
501
+ [FH]
502
+ [F]
503
+ [H]
504
+ [He]
505
+ [I+2]
506
+ [I+3]
507
+ [I+]
508
+ [IH]
509
+ [I]
510
+ [K]
511
+ [Kr]
512
+ [Li+]
513
+ [LiH]
514
+ [MgH2]
515
+ [Mg]
516
+ [N+]
517
+ [N-]
518
+ [N@+]
519
+ [N@@+]
520
+ [N@@]
521
+ [N@]
522
+ [NH+]
523
+ [NH-]
524
+ [NH2+]
525
+ [NH3]
526
+ [NH]
527
+ [N]
528
+ [Na]
529
+ [O+]
530
+ [O-]
531
+ [OH+]
532
+ [OH2]
533
+ [OH]
534
+ [O]
535
+ [P+]
536
+ [P@+]
537
+ [P@@+]
538
+ [P@@]
539
+ [P@]
540
+ [PH2]
541
+ [PH]
542
+ [P]
543
+ [Ra]
544
+ [Rb]
545
+ [S+]
546
+ [S-]
547
+ [S@+]
548
+ [S@@+]
549
+ [S@@]
550
+ [S@]
551
+ [SH+]
552
+ [SH2]
553
+ [SH]
554
+ [S]
555
+ [Se+]
556
+ [Se-2]
557
+ [SeH2]
558
+ [SeH]
559
+ [Se]
560
+ [Si@]
561
+ [SiH2]
562
+ [SiH]
563
+ [Si]
564
+ [SrH2]
565
+ [TeH]
566
+ [Te]
567
+ [Xe]
568
+ [Zn+2]
569
+ [Zn-2]
570
+ [Zn]
571
+ [b-]
572
+ [c+]
573
+ [c-]
574
+ [cH-]
575
+ [cH]
576
+ [c]
577
+ [n+]
578
+ [n-]
579
+ [nH]
580
+ [n]
581
+ [o+]
582
+ [s+]
583
+ [se+]
584
+ [se]
585
+ [te+]
586
+ [te]
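
Note (not part of the commit): with the usual BERT-style load_vocab used by the tokenizers above, each token's id is its 0-based line position in this file, so the special tokens [PAD]..[MASK] land on ids 0-4 and the vocabulary size is 586. A tiny sketch of that mapping; load_vocab_sketch is a hypothetical stand-in, not the helper from the commit.

import collections

def load_vocab_sketch(vocab_file):
    # Hypothetical helper for illustration: id = 0-based line index.
    vocab = collections.OrderedDict()
    with open(vocab_file, "r", encoding="utf-8") as reader:
        for index, line in enumerate(reader):
            vocab[line.rstrip("\n")] = index
    return vocab

vocab = load_vocab_sketch("tokenizer/new_vocab.txt")
print(vocab["[PAD]"], vocab["[MASK]"], len(vocab))  # 0 4 586 for the file above
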
training_classifiers/.gitignore ADDED
File without changes
training_classifiers/.ipynb_checkpoints/binding_affinity_iptm-checkpoint.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ extract_iptm_affinity_csv_all.py
4
+
5
+ Writes:
6
+ - out_dir/wt_iptm_affinity_all.csv
7
+ - out_dir/smiles_iptm_affinity_all.csv
8
+
9
+ Also prints:
10
+ - N
11
+ - Spearman rho (affinity vs iptm)
12
+ - Pearson r (affinity vs iptm)
13
+ """
14
+
15
+ from pathlib import Path
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+
20
+ def corr_stats(df: pd.DataFrame, x: str, y: str):
21
+ # pandas handles NaNs if we already dropped them; still be safe
22
+ xx = pd.to_numeric(df[x], errors="coerce")
23
+ yy = pd.to_numeric(df[y], errors="coerce")
24
+ m = xx.notna() & yy.notna()
25
+ xx = xx[m]
26
+ yy = yy[m]
27
+ n = int(m.sum())
28
+
29
+ # Pearson r
30
+ pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
31
+ # Spearman rho
32
+ spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")
33
+
34
+ return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}
35
+
36
+
37
+ def clean_one(
38
+ in_csv: Path,
39
+ out_csv: Path,
40
+ iptm_col: str,
41
+ affinity_col: str = "affinity",
42
+ keep_cols=(),
43
+ ):
44
+ df = pd.read_csv(in_csv)
45
+
46
+ # affinity + iptm must exist
47
+ need = [affinity_col, iptm_col]
48
+ missing = [c for c in need if c not in df.columns]
49
+ if missing:
50
+ raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")
51
+
52
+ # coerce numeric
53
+ df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
54
+ df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")
55
+
56
+ # drop NaNs in either
57
+ df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)
58
+
59
+ # output cols (standardize names)
60
+ out = pd.DataFrame({
61
+ "affinity": df[affinity_col].astype(float),
62
+ "iptm": df[iptm_col].astype(float),
63
+ })
64
+
65
+ # keep split if present (handy for coloring later, but not used for corr)
66
+ if "split" in df.columns:
67
+ out.insert(0, "split", df["split"].astype(str))
68
+
69
+ # optional extras for labeling/debug
70
+ for c in keep_cols:
71
+ if c in df.columns:
72
+ out[c] = df[c]
73
+
74
+ out_csv.parent.mkdir(parents=True, exist_ok=True)
75
+ out.to_csv(out_csv, index=False)
76
+
77
+ stats = corr_stats(out, "iptm", "affinity")
78
+ print(f"[write] {out_csv}")
79
+ print(f" N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")
80
+
81
+ # also save stats json next to csv
82
+ stats_path = out_csv.with_suffix(".stats.json")
83
+ with open(stats_path, "w") as f:
84
+ import json
85
+ json.dump(
86
+ {
87
+ "input_csv": str(in_csv),
88
+ "output_csv": str(out_csv),
89
+ "iptm_col": iptm_col,
90
+ "affinity_col": affinity_col,
91
+ **stats,
92
+ },
93
+ f,
94
+ indent=2,
95
+ )
96
+
97
+
98
+ def main():
99
+ import argparse
100
+ ap = argparse.ArgumentParser()
101
+ ap.add_argument("--wt_meta_csv", type=str, required=True)
102
+ ap.add_argument("--smiles_meta_csv", type=str, required=True)
103
+ ap.add_argument("--out_dir", type=str, required=True)
104
+
105
+ ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
106
+ ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
107
+ ap.add_argument("--affinity_col", type=str, default="affinity")
108
+ args = ap.parse_args()
109
+
110
+ out_dir = Path(args.out_dir)
111
+
112
+ clean_one(
113
+ Path(args.wt_meta_csv),
114
+ out_dir / "wt_iptm_affinity_all.csv",
115
+ iptm_col=args.wt_iptm_col,
116
+ affinity_col=args.affinity_col,
117
+ keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
118
+ )
119
+
120
+ clean_one(
121
+ Path(args.smiles_meta_csv),
122
+ out_dir / "smiles_iptm_affinity_all.csv",
123
+ iptm_col=args.smiles_iptm_col,
124
+ affinity_col=args.affinity_col,
125
+ keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
126
+ )
127
+
128
+ print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
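
For reference (not part of the commit): when run as a script this takes --wt_meta_csv, --smiles_meta_csv and --out_dir, plus the optional column-name flags defined in main(). The correlation it reports reduces to pandas' pairwise corr after numeric coercion; a minimal sketch with toy numbers:

import pandas as pd

# Toy values, not real data; mirrors what corr_stats() computes.
toy = pd.DataFrame({
    "affinity": [6.2, 7.5, 8.1, 9.3, 5.9],
    "iptm":     [0.41, 0.55, 0.62, 0.80, 0.38],
})

xx = pd.to_numeric(toy["iptm"], errors="coerce")
yy = pd.to_numeric(toy["affinity"], errors="coerce")
print("Pearson r:", xx.corr(yy, method="pearson"))
print("Spearman rho:", xx.corr(yy, method="spearman"))
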
training_classifiers/.ipynb_checkpoints/binding_affinity_split-checkpoint.py ADDED
@@ -0,0 +1,847 @@
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import math
4
+ from pathlib import Path
5
+ import sys
6
+ from contextlib import contextmanager
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+
12
+ # tqdm is optional; we’ll disable it by default in notebooks
13
+ from tqdm import tqdm
14
+
15
+ sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
16
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
17
+
18
+ from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
19
+ from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
20
+
21
+ # -------------------------
22
+ # Config
23
+ # -------------------------
24
+ CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
25
+
26
+ OUT_ROOT = Path(
27
+ "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
28
+ )
29
+
30
+ # WT (seq) embedding model
31
+ WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
32
+ WT_MAX_LEN = 1022
33
+ WT_BATCH = 32
34
+
35
+ # SMILES embedding model + tokenizer
36
+ SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
37
+ TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
38
+ TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
39
+ SMI_MAX_LEN = 768
40
+ SMI_BATCH = 128
41
+
42
+ # Split config
43
+ TRAIN_FRAC = 0.80
44
+ RANDOM_SEED = 1986
45
+ AFFINITY_Q_BINS = 30
46
+
47
+ # Columns expected in CSV
48
+ COL_SEQ1 = "seq1"
49
+ COL_SEQ2 = "seq2"
50
+ COL_AFF = "affinity"
51
+ COL_F2S = "Fasta2SMILES"
52
+ COL_REACT = "REACT_SMILES"
53
+ COL_WT_IPTM = "wt_iptm_score"
54
+ COL_SMI_IPTM = "smiles_iptm_score"
55
+
56
+ # Device
57
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+
59
+ # -------------------------
60
+ # Quiet / notebook-safe output controls
61
+ # -------------------------
62
+ QUIET = True # suppress most prints
63
+ USE_TQDM = False # disable tqdm bars (recommended in Jupyter to avoid crashing)
64
+ LOG_FILE = None # optionally: OUT_ROOT / "build.log"
65
+
66
+ def log(msg: str):
67
+ if LOG_FILE is not None:
68
+ Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
69
+ with open(LOG_FILE, "a") as f:
70
+ f.write(msg.rstrip() + "\n")
71
+ if not QUIET:
72
+ print(msg)
73
+
74
+ def pbar(it, **kwargs):
75
+ return tqdm(it, **kwargs) if USE_TQDM else it
76
+
77
+ @contextmanager
78
+ def section(title: str):
79
+ log(f"\n=== {title} ===")
80
+ yield
81
+ log(f"=== done: {title} ===")
82
+
83
+
84
+ # -------------------------
85
+ # Helpers
86
+ # -------------------------
87
+ def has_uaa(seq: str) -> bool:
88
+ return "X" in str(seq).upper()
89
+
90
+ def affinity_to_class(a: float) -> str:
91
+ # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
92
+ if a >= 9.0:
93
+ return "High"
94
+ elif a >= 7.0:
95
+ return "Moderate"
96
+ else:
97
+ return "Low"
98
+
99
+ def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
100
+ df = df.copy()
101
+
102
+ df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
103
+ df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
104
+
105
+ df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
106
+
107
+ try:
108
+ df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
109
+ strat_col = "aff_bin"
110
+ except Exception:
111
+ df["aff_bin"] = df["affinity_class"]
112
+ strat_col = "aff_bin"
113
+
114
+ rng = np.random.RandomState(RANDOM_SEED)
115
+
116
+ df["split"] = None
117
+ for _, g in df.groupby(strat_col, observed=True):
118
+ idx = g.index.to_numpy()
119
+ rng.shuffle(idx)
120
+ n_train = int(math.floor(len(idx) * TRAIN_FRAC))
121
+ df.loc[idx[:n_train], "split"] = "train"
122
+ df.loc[idx[n_train:], "split"] = "val"
123
+
124
+ df["split"] = df["split"].fillna("train")
125
+ return df
126
+
127
+ def _summ(x):
128
+ x = np.asarray(x, dtype=float)
129
+ x = x[~np.isnan(x)]
130
+ if len(x) == 0:
131
+ return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
132
+ return {
133
+ "n": int(len(x)),
134
+ "mean": float(np.mean(x)),
135
+ "std": float(np.std(x)),
136
+ "p50": float(np.quantile(x, 0.50)),
137
+ "p95": float(np.quantile(x, 0.95)),
138
+ }
139
+
140
+ def _len_stats(seqs):
141
+ lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
142
+ if len(lens) == 0:
143
+ return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
144
+ return {
145
+ "n": int(len(lens)),
146
+ "mean": float(lens.mean()),
147
+ "std": float(lens.std()),
148
+ "p50": float(np.quantile(lens, 0.50)),
149
+ "p95": float(np.quantile(lens, 0.95)),
150
+ }
151
+
152
+ def verify_split_before_embedding(
153
+ df2: pd.DataFrame,
154
+ affinity_col: str,
155
+ split_col: str,
156
+ seq_col: str,
157
+ iptm_col: str,
158
+ aff_class_col: str = "affinity_class",
159
+ aff_bins: int = 30,
160
+ save_report_prefix: str | None = None,
161
+ verbose: bool = False,
162
+ ):
163
+ """
164
+ Notebook-safe: by default prints only ONE line via `log()`.
165
+ Optionally writes CSV reports (stats + class proportions).
166
+ """
167
+ df2 = df2.copy()
168
+ df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
169
+ df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
170
+
171
+ assert split_col in df2.columns, f"Missing split col: {split_col}"
172
+ assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
173
+ assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
174
+
175
+ try:
176
+ df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
177
+ except Exception:
178
+ df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
179
+
180
+ tr = df2[df2[split_col] == "train"].reset_index(drop=True)
181
+ va = df2[df2[split_col] == "val"].reset_index(drop=True)
182
+
183
+ tr_aff = _summ(tr[affinity_col].to_numpy())
184
+ va_aff = _summ(va[affinity_col].to_numpy())
185
+ tr_len = _len_stats(tr[seq_col].tolist())
186
+ va_len = _len_stats(va[seq_col].tolist())
187
+
188
+ # bin drift
189
+ bin_ct = (
190
+ df2.groupby([split_col, "_aff_bin_dbg"])
191
+ .size()
192
+ .groupby(level=0)
193
+ .apply(lambda s: s / s.sum())
194
+ )
195
+ tr_bins = bin_ct.loc["train"]
196
+ va_bins = bin_ct.loc["val"]
197
+ all_bins = tr_bins.index.union(va_bins.index)
198
+ tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
199
+ va_bins = va_bins.reindex(all_bins, fill_value=0.0)
200
+ max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
201
+
202
+ msg = (
203
+ f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
204
+ f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
205
+ f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
206
+ f"max_bin_diff={max_bin_diff:.4f}"
207
+ )
208
+ log(msg)
209
+
210
+ if verbose and (not QUIET):
211
+ class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
212
+ class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
213
+ print("\n[verbose] affinity_class counts:\n", class_ct)
214
+ print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
215
+
216
+ if save_report_prefix is not None:
217
+ out = Path(save_report_prefix)
218
+ out.parent.mkdir(parents=True, exist_ok=True)
219
+
220
+ stats_df = pd.DataFrame([
221
+ {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
222
+ {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
223
+ ])
224
+ class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
225
+ class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
226
+
227
+ stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
228
+ class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
229
+
230
+
231
+ # -------------------------
232
+ # WT pooled (ESM2)
233
+ # -------------------------
234
+ @torch.no_grad()
235
+ def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
236
+ embs = []
237
+ for i in pbar(range(0, len(seqs), batch_size)):
238
+ batch = seqs[i:i + batch_size]
239
+ inputs = tokenizer(
240
+ batch,
241
+ padding=True,
242
+ truncation=True,
243
+ max_length=max_length,
244
+ return_tensors="pt",
245
+ )
246
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
247
+ out = model(**inputs)
248
+ h = out.last_hidden_state # (B, L, H)
249
+
250
+ attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
251
+ summed = (h * attn).sum(dim=1) # (B, H)
252
+ denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
253
+ pooled = (summed / denom).detach().cpu().numpy()
254
+ embs.append(pooled)
255
+
256
+ return np.vstack(embs)
257
+
258
+
259
+ # -------------------------
260
+ # WT unpooled (ESM2)
261
+ # -------------------------
262
+ @torch.no_grad()
263
+ def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
264
+ tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
265
+ tok = {k: v.to(DEVICE) for k, v in tok.items()}
266
+ out = model(**tok)
267
+ h = out.last_hidden_state[0] # (L, H)
268
+ attn = tok["attention_mask"][0].bool() # (L,)
269
+ ids = tok["input_ids"][0]
270
+
271
+ keep = attn.clone()
272
+ if cls_id is not None:
273
+ keep &= (ids != cls_id)
274
+ if eos_id is not None:
275
+ keep &= (ids != eos_id)
276
+
277
+ return h[keep].detach().cpu().to(torch.float16).numpy()
278
+
279
+ def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
280
+ """
281
+ Expects df_split to have:
282
+ - target_sequence (seq1)
283
+ - sequence (binder seq2; WT binder)
284
+ - label, affinity_class, COL_AFF, COL_WT_IPTM
285
+ Saves a dataset where each row contains BOTH:
286
+ - target_embedding (Lt,H), target_attention_mask, target_length
287
+ - binder_embedding (Lb,H), binder_attention_mask, binder_length
288
+ """
289
+ cls_id = tokenizer.cls_token_id
290
+ eos_id = tokenizer.eos_token_id
291
+ H = model.config.hidden_size
292
+
293
+ features = Features({
294
+ "target_sequence": Value("string"),
295
+ "sequence": Value("string"),
296
+ "label": Value("float32"),
297
+ "affinity": Value("float32"),
298
+ "affinity_class": Value("string"),
299
+
300
+ "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
301
+ "target_attention_mask": HFSequence(Value("int8")),
302
+ "target_length": Value("int64"),
303
+
304
+ "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
305
+ "binder_attention_mask": HFSequence(Value("int8")),
306
+ "binder_length": Value("int64"),
307
+
308
+ COL_WT_IPTM: Value("float32"),
309
+ COL_AFF: Value("float32"),
310
+ })
311
+
312
+ def gen_rows(df: pd.DataFrame):
313
+ for r in pbar(df.itertuples(index=False), total=len(df)):
314
+ tgt = str(getattr(r, "target_sequence")).strip()
315
+ bnd = str(getattr(r, "sequence")).strip()
316
+
317
+ y = float(getattr(r, "label"))
318
+ aff = float(getattr(r, COL_AFF))
319
+ acls = str(getattr(r, "affinity_class"))
320
+
321
+ iptm = getattr(r, COL_WT_IPTM)
322
+ iptm = float(iptm) if pd.notna(iptm) else np.nan
323
+
324
+ # token embeddings for target + binder (both ESM)
325
+ t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
326
+ b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
327
+
328
+ t_list = t_emb.tolist()
329
+ b_list = b_emb.tolist()
330
+ Lt = len(t_list)
331
+ Lb = len(b_list)
332
+
333
+ yield {
334
+ "target_sequence": tgt,
335
+ "sequence": bnd,
336
+ "label": np.float32(y),
337
+ "affinity": np.float32(aff),
338
+ "affinity_class": acls,
339
+
340
+ "target_embedding": t_list,
341
+ "target_attention_mask": [1] * Lt,
342
+ "target_length": int(Lt),
343
+
344
+ "binder_embedding": b_list,
345
+ "binder_attention_mask": [1] * Lb,
346
+ "binder_length": int(Lb),
347
+
348
+ COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
349
+ COL_AFF: np.float32(aff),
350
+ }
351
+
352
+ out_dir.mkdir(parents=True, exist_ok=True)
353
+ ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
354
+ ds.save_to_disk(str(out_dir), max_shard_size="1GB")
355
+ return ds
356
+
357
+ def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
358
+ smi_tok, smi_roformer):
359
+ """
360
+ df_split must have:
361
+ - target_sequence (seq1)
362
+ - sequence (binder smiles string)
363
+ - label, affinity_class, COL_AFF, COL_SMI_IPTM
364
+ Saves rows with:
365
+ target_embedding (Lt,Ht) from ESM
366
+ binder_embedding (Lb,Hb) from PeptideCLM
367
+ """
368
+ cls_id = wt_tokenizer.cls_token_id
369
+ eos_id = wt_tokenizer.eos_token_id
370
+ Ht = wt_model_unpooled.config.hidden_size
371
+
372
+ # Infer Hb from one forward pass? easiest: run one mini batch outside in main if you want.
373
+ # Here: we’ll infer from model config if available.
374
+ Hb = getattr(smi_roformer.config, "hidden_size", None)
375
+ if Hb is None:
376
+ Hb = getattr(smi_roformer.config, "dim", None)
377
+ if Hb is None:
378
+ raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
379
+
380
+ features = Features({
381
+ "target_sequence": Value("string"),
382
+ "sequence": Value("string"),
383
+ "label": Value("float32"),
384
+ "affinity": Value("float32"),
385
+ "affinity_class": Value("string"),
386
+
387
+ "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
388
+ "target_attention_mask": HFSequence(Value("int8")),
389
+ "target_length": Value("int64"),
390
+
391
+ "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
392
+ "binder_attention_mask": HFSequence(Value("int8")),
393
+ "binder_length": Value("int64"),
394
+
395
+ COL_SMI_IPTM: Value("float32"),
396
+ COL_AFF: Value("float32"),
397
+ })
398
+
399
+ def gen_rows(df: pd.DataFrame):
400
+ for r in pbar(df.itertuples(index=False), total=len(df)):
401
+ tgt = str(getattr(r, "target_sequence")).strip()
402
+ bnd = str(getattr(r, "sequence")).strip()
403
+
404
+ y = float(getattr(r, "label"))
405
+ aff = float(getattr(r, COL_AFF))
406
+ acls = str(getattr(r, "affinity_class"))
407
+
408
+ iptm = getattr(r, COL_SMI_IPTM)
409
+ iptm = float(iptm) if pd.notna(iptm) else np.nan
410
+
411
+ # target token embeddings (ESM)
412
+ t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
413
+ t_list = t_emb.tolist()
414
+ Lt = len(t_list)
415
+
416
+ # binder token embeddings (PeptideCLM) — single-item batch
417
+ _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
418
+ [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
419
+ )
420
+ b_emb = tok_list[0] # np.float16 (Lb, Hb)
421
+ b_list = b_emb.tolist()
422
+ Lb = int(lengths[0])
423
+ b_mask = mask_list[0].astype(np.int8).tolist()
424
+
425
+ yield {
426
+ "target_sequence": tgt,
427
+ "sequence": bnd,
428
+ "label": np.float32(y),
429
+ "affinity": np.float32(aff),
430
+ "affinity_class": acls,
431
+
432
+ "target_embedding": t_list,
433
+ "target_attention_mask": [1] * Lt,
434
+ "target_length": int(Lt),
435
+
436
+ "binder_embedding": b_list,
437
+ "binder_attention_mask": [int(x) for x in b_mask],
438
+ "binder_length": int(Lb),
439
+
440
+ COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
441
+ COL_AFF: np.float32(aff),
442
+ }
443
+
444
+ out_dir.mkdir(parents=True, exist_ok=True)
445
+ ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
446
+ ds.save_to_disk(str(out_dir), max_shard_size="1GB")
447
+ return ds
448
+
449
+
450
+ # -------------------------
451
+ # SMILES pooled + unpooled (PeptideCLM)
452
+ # -------------------------
453
+ def get_special_ids(tokenizer_obj):
454
+ cand = [
455
+ getattr(tokenizer_obj, "pad_token_id", None),
456
+ getattr(tokenizer_obj, "cls_token_id", None),
457
+ getattr(tokenizer_obj, "sep_token_id", None),
458
+ getattr(tokenizer_obj, "bos_token_id", None),
459
+ getattr(tokenizer_obj, "eos_token_id", None),
460
+ getattr(tokenizer_obj, "mask_token_id", None),
461
+ ]
462
+ return sorted({x for x in cand if x is not None})
463
+
464
+ @torch.no_grad()
465
+ def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
466
+ tok = tokenizer_obj(
467
+ batch_sequences,
468
+ return_tensors="pt",
469
+ padding=True,
470
+ truncation=True,
471
+ max_length=max_length,
472
+ )
473
+ input_ids = tok["input_ids"].to(DEVICE)
474
+ attention_mask = tok["attention_mask"].to(DEVICE)
475
+
476
+ outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
477
+ last_hidden = outputs.last_hidden_state # (B, L, H)
478
+
479
+ special_ids = get_special_ids(tokenizer_obj)
480
+ valid = attention_mask.bool()
481
+ if len(special_ids) > 0:
482
+ sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
483
+ if hasattr(torch, "isin"):
484
+ valid = valid & (~torch.isin(input_ids, sid))
485
+ else:
486
+ m = torch.zeros_like(valid)
487
+ for s in special_ids:
488
+ m |= (input_ids == s)
489
+ valid = valid & (~m)
490
+
491
+ valid_f = valid.unsqueeze(-1).float()
492
+ summed = torch.sum(last_hidden * valid_f, dim=1)
493
+ denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
494
+ pooled = (summed / denom).detach().cpu().numpy()
495
+
496
+ token_emb_list, mask_list, lengths = [], [], []
497
+ for b in range(last_hidden.shape[0]):
498
+ emb = last_hidden[b, valid[b]] # (Li, H)
499
+ token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
500
+ li = emb.shape[0]
501
+ lengths.append(int(li))
502
+ mask_list.append(np.ones((li,), dtype=np.int8))
503
+
504
+ return pooled, token_emb_list, mask_list, lengths
505
+
506
+ def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
507
+ pooled_all = []
508
+ token_emb_all = []
509
+ mask_all = []
510
+ lengths_all = []
511
+
512
+ for i in pbar(range(0, len(seqs), batch_size)):
513
+ batch = seqs[i:i + batch_size]
514
+ pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
515
+ batch, tokenizer_obj, model_roformer, max_length
516
+ )
517
+ pooled_all.append(pooled)
518
+ token_emb_all.extend(tok_list)
519
+ mask_all.extend(m_list)
520
+ lengths_all.extend(lens)
521
+
522
+ return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
523
+
524
+ # -------------------------
525
+ # Target embedding cache (NO extra ESM runs)
526
+ # We will compute target pooled embeddings ONCE from WT view, then reuse for SMILES.
527
+ # -------------------------
528
+ def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
529
+ wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
530
+ wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
531
+
532
+ # compute target pooled embeddings once
533
+ tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
534
+ tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
535
+
536
+ wt_train_tgt_emb = wt_pooled_embeddings(
537
+ tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
538
+ )
539
+ wt_val_tgt_emb = wt_pooled_embeddings(
540
+ tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
541
+ )
542
+
543
+ # build dict: target_sequence -> embedding (float32 array)
544
+ # if duplicates exist, last wins; you can add checks if needed
545
+ train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
546
+ val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
547
+ return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
548
+ # -------------------------
549
+ # Main
550
+ # -------------------------
551
+ def main():
552
+ log(f"[INFO] DEVICE: {DEVICE}")
553
+ OUT_ROOT.mkdir(parents=True, exist_ok=True)
554
+
555
+ # 1) Load
556
+ with section("load csv + dedup"):
557
+ df = pd.read_csv(CSV_PATH)
558
+ for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
559
+ if c in df.columns:
560
+ df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
561
+
562
+ # Dedup on the full identity tuple you want
563
+ DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
564
+ df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
565
+
566
+ print("Rows after dedup on", DEDUP_COLS, ":", len(df))
567
+
568
+ need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
569
+ missing = [c for c in need if c not in df.columns]
570
+ if missing:
571
+ raise ValueError(f"Missing required columns: {missing}")
572
+
573
+ # numeric affinity for both branches
574
+ df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
575
+
576
+ # 2) Build WT subset + SMILES subset separately (NO global dropping)
577
+ with section("prepare wt/smiles subsets"):
578
+ # WT: requires a canonical peptide sequence (no X) + affinity
579
+ df_wt = df.copy()
580
+ df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
581
+ df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
582
+ df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
583
+ df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
584
+
585
+ # SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
586
+ df_smi = df.copy()
587
+ df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
588
+ df_smi = df_smi[
589
+ pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
590
+ ].reset_index(drop=True) # an empty iptm indicates something is wrong with the SMILES sequence
591
+
592
+ is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
593
+ df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
594
+ df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
595
+ df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
596
+ df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
597
+
598
+ log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
599
+
600
+ # 3) Split separately (different sizes and memberships are expected)
601
+ with section("split wt and smiles separately"):
602
+ df_wt2 = make_distribution_matched_split(df_wt)
603
+ df_smi2 = make_distribution_matched_split(df_smi)
604
+
605
+ # save split tables
606
+ wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
607
+ smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
608
+ df_wt2.to_csv(wt_split_csv, index=False)
609
+ df_smi2.to_csv(smi_split_csv, index=False)
610
+ log(f"Saved WT split meta: {wt_split_csv}")
611
+ log(f"Saved SMILES split meta: {smi_split_csv}")
612
+
613
+ # lightweight double-check (one-line)
614
+ verify_split_before_embedding(
615
+ df2=df_wt2,
616
+ affinity_col=COL_AFF,
617
+ split_col="split",
618
+ seq_col="wt_sequence",
619
+ iptm_col=COL_WT_IPTM,
620
+ aff_class_col="affinity_class",
621
+ aff_bins=AFFINITY_Q_BINS,
622
+ save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
623
+ verbose=False,
624
+ )
625
+ verify_split_before_embedding(
626
+ df2=df_smi2,
627
+ affinity_col=COL_AFF,
628
+ split_col="split",
629
+ seq_col="smiles_sequence",
630
+ iptm_col=COL_SMI_IPTM,
631
+ aff_class_col="affinity_class",
632
+ aff_bins=AFFINITY_Q_BINS,
633
+ save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
634
+ verbose=False,
635
+ )
636
+
637
+ # Prepare split views
638
+ def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
639
+ out = df_in.copy()
640
+ out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # <-- NEW
641
+ out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
642
+ out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
643
+ out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
644
+ out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
645
+ out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
646
+ return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
647
+
648
+ wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
649
+ smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
650
+
651
+ # -------------------------
652
+ # Split views
653
+ # -------------------------
654
+ wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
655
+ wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
656
+ smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
657
+ smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
658
+
659
+
660
+ # =========================
661
+ # TARGET pooled embeddings (ESM) — SEPARATE per branch
662
+ # =========================
663
+ with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
664
+ wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
665
+ wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
666
+
667
+ # ---- WT targets ----
668
+ wt_train_tgt_emb = wt_pooled_embeddings(
669
+ wt_train["target_sequence"].astype(str).str.strip().tolist(),
670
+ wt_tok, wt_esm,
671
+ batch_size=WT_BATCH,
672
+ max_length=WT_MAX_LEN,
673
+ ).astype(np.float32)
674
+
675
+ wt_val_tgt_emb = wt_pooled_embeddings(
676
+ wt_val["target_sequence"].astype(str).str.strip().tolist(),
677
+ wt_tok, wt_esm,
678
+ batch_size=WT_BATCH,
679
+ max_length=WT_MAX_LEN,
680
+ ).astype(np.float32)
681
+
682
+ # ---- SMILES targets (independent; may include UAA-only targets) ----
683
+ smi_train_tgt_emb = wt_pooled_embeddings(
684
+ smi_train["target_sequence"].astype(str).str.strip().tolist(),
685
+ wt_tok, wt_esm,
686
+ batch_size=WT_BATCH,
687
+ max_length=WT_MAX_LEN,
688
+ ).astype(np.float32)
689
+
690
+ smi_val_tgt_emb = wt_pooled_embeddings(
691
+ smi_val["target_sequence"].astype(str).str.strip().tolist(),
692
+ wt_tok, wt_esm,
693
+ batch_size=WT_BATCH,
694
+ max_length=WT_MAX_LEN,
695
+ ).astype(np.float32)
696
+
697
+
698
+ # =========================
699
+ # WT pooled binder embeddings (binder = WT peptide)
700
+ # =========================
701
+ with section("WT pooled binder embeddings + save"):
702
+ wt_train_emb = wt_pooled_embeddings(
703
+ wt_train["sequence"].astype(str).str.strip().tolist(),
704
+ wt_tok, wt_esm,
705
+ batch_size=WT_BATCH,
706
+ max_length=WT_MAX_LEN,
707
+ ).astype(np.float32)
708
+
709
+ wt_val_emb = wt_pooled_embeddings(
710
+ wt_val["sequence"].astype(str).str.strip().tolist(),
711
+ wt_tok, wt_esm,
712
+ batch_size=WT_BATCH,
713
+ max_length=WT_MAX_LEN,
714
+ ).astype(np.float32)
715
+
716
+ wt_train_ds = Dataset.from_dict({
717
+ "target_sequence": wt_train["target_sequence"].tolist(),
718
+ "sequence": wt_train["sequence"].tolist(),
719
+ "label": wt_train["label"].astype(float).tolist(),
720
+ "target_embedding": wt_train_tgt_emb,
721
+ "embedding": wt_train_emb,
722
+ COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
723
+ COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
724
+ "affinity_class": wt_train["affinity_class"].tolist(),
725
+ })
726
+
727
+ wt_val_ds = Dataset.from_dict({
728
+ "target_sequence": wt_val["target_sequence"].tolist(),
729
+ "sequence": wt_val["sequence"].tolist(),
730
+ "label": wt_val["label"].astype(float).tolist(),
731
+ "target_embedding": wt_val_tgt_emb,
732
+ "embedding": wt_val_emb,
733
+ COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
734
+ COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
735
+ "affinity_class": wt_val["affinity_class"].tolist(),
736
+ })
737
+
738
+ wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
739
+ wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
740
+ wt_pooled_dd.save_to_disk(str(wt_pooled_out))
741
+ log(f"Saved WT pooled -> {wt_pooled_out}")
742
+
743
+
744
+ # =========================
745
+ # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
746
+ # =========================
747
+ with section("SMILES pooled binder embeddings + save"):
748
+ smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
749
+ smi_roformer = (
750
+ AutoModelForMaskedLM
751
+ .from_pretrained(SMI_MODEL_NAME)
752
+ .roformer
753
+ .to(DEVICE)
754
+ .eval()
755
+ )
756
+
757
+ smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
758
+ smi_train["sequence"].astype(str).str.strip().tolist(),
759
+ smi_tok, smi_roformer,
760
+ batch_size=SMI_BATCH,
761
+ max_length=SMI_MAX_LEN,
762
+ )
763
+
764
+ smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
765
+ smi_val["sequence"].astype(str).str.strip().tolist(),
766
+ smi_tok, smi_roformer,
767
+ batch_size=SMI_BATCH,
768
+ max_length=SMI_MAX_LEN,
769
+ )
770
+
771
+ smi_train_ds = Dataset.from_dict({
772
+ "target_sequence": smi_train["target_sequence"].tolist(),
773
+ "sequence": smi_train["sequence"].tolist(),
774
+ "label": smi_train["label"].astype(float).tolist(),
775
+ "target_embedding": smi_train_tgt_emb,
776
+ "embedding": smi_train_pooled.astype(np.float32),
777
+ COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
778
+ COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
779
+ "affinity_class": smi_train["affinity_class"].tolist(),
780
+ })
781
+
782
+ smi_val_ds = Dataset.from_dict({
783
+ "target_sequence": smi_val["target_sequence"].tolist(),
784
+ "sequence": smi_val["sequence"].tolist(),
785
+ "label": smi_val["label"].astype(float).tolist(),
786
+ "target_embedding": smi_val_tgt_emb,
787
+ "embedding": smi_val_pooled.astype(np.float32),
788
+ COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
789
+ COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
790
+ "affinity_class": smi_val["affinity_class"].tolist(),
791
+ })
792
+
793
+ smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
794
+ smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
795
+ smi_pooled_dd.save_to_disk(str(smi_pooled_out))
796
+ log(f"Saved SMILES pooled -> {smi_pooled_out}")
797
+
798
+
799
+ # =========================
800
+ # WT unpooled paired (ESM target + ESM binder) + save
801
+ # =========================
802
+ with section("WT unpooled paired embeddings + save"):
803
+ wt_tok_unpooled = wt_tok # reuse tokenizer
804
+ wt_esm_unpooled = wt_esm # reuse model
805
+
806
+ wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
807
+ wt_unpooled_dd = DatasetDict({
808
+ "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
809
+ wt_tok_unpooled, wt_esm_unpooled),
810
+ "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
811
+ wt_tok_unpooled, wt_esm_unpooled),
812
+ })
813
+ # (Optional) also save as DatasetDict root if you want a single load_from_disk path:
814
+ wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
815
+ log(f"Saved WT unpooled -> {wt_unpooled_out}")
816
+
817
+
818
+ # =========================
819
+ # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
820
+ # =========================
821
+ with section("SMILES unpooled paired embeddings + save"):
822
+ # reuse already-loaded smi_tok/smi_roformer from pooled section if still in scope;
823
+ # otherwise re-init here:
824
+ # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
825
+ # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()
826
+
827
+ smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
828
+ smi_unpooled_dd = DatasetDict({
829
+ "train": build_smiles_unpooled_paired_dataset(
830
+ smi_train, smi_unpooled_out / "train",
831
+ wt_tok, wt_esm,
832
+ smi_tok, smi_roformer
833
+ ),
834
+ "val": build_smiles_unpooled_paired_dataset(
835
+ smi_val, smi_unpooled_out / "val",
836
+ wt_tok, wt_esm,
837
+ smi_tok, smi_roformer
838
+ ),
839
+ })
840
+ smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
841
+ log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
842
+
843
+ log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
844
+
845
+
846
+ if __name__ == "__main__":
847
+ main()
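
For reference, a minimal sketch of loading one of the DatasetDicts saved above back for inspection; the local root path is an assumption, and the column list is abbreviated:

from datasets import load_from_disk

# Load the pooled WT/WT pair dataset written under OUT_ROOT (placeholder root used here).
dd = load_from_disk("training_data_cleaned/pair_wt_wt_pooled")
print(dd)                           # DatasetDict with "train" and "val" splits
print(dd["train"].column_names)     # target_sequence, sequence, label, target_embedding, embedding, ...
row = dd["train"][0]
print(len(row["target_embedding"]), len(row["embedding"]))  # pooled target / binder vector sizes
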
training_classifiers/.ipynb_checkpoints/binding_training-checkpoint.py ADDED
@@ -0,0 +1,414 @@
1
+ import os, json
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import DataLoader
7
+ import optuna
8
+ from datasets import load_from_disk, DatasetDict
9
+ from scipy.stats import spearmanr
10
+ from lightning.pytorch import seed_everything
11
+ seed_everything(1986)
12
+
13
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
+
15
+
16
+ def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
17
+ rho = spearmanr(y_true, y_pred).correlation
18
+ if rho is None or np.isnan(rho):
19
+ return 0.0
20
+ return float(rho)
21
+
22
+
23
+ # -----------------------------
24
+ # Affinity class thresholds (final spec)
25
+ # High >= 9 ; Moderate 7 <= y < 9 ; Low < 7
26
+ # 0=High, 1=Moderate, 2=Low
27
+ # -----------------------------
28
+ def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
29
+ high = y >= 9.0
30
+ low = y < 7.0
31
+ mid = ~(high | low)
32
+ cls = torch.zeros_like(y, dtype=torch.long)
33
+ cls[mid] = 1
34
+ cls[low] = 2
35
+ return cls
36
+
37
+
38
+ # -----------------------------
39
+ # Load paired DatasetDict
40
+ # -----------------------------
41
+ def load_split_paired(path: str):
42
+ dd = load_from_disk(path)
43
+ if not isinstance(dd, DatasetDict):
44
+ raise ValueError(f"Expected DatasetDict at {path}")
45
+ if "train" not in dd or "val" not in dd:
46
+ raise ValueError(f"DatasetDict missing train/val at {path}")
47
+ return dd["train"], dd["val"]
48
+
49
+
50
+ # -----------------------------
51
+ # Collate: pooled paired
52
+ # -----------------------------
53
+ def collate_pair_pooled(batch):
54
+ Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) # (B,Ht)
55
+ Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) # (B,Hb)
56
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
57
+ return Pt, Pb, y
58
+
59
+
60
+ # -----------------------------
61
+ # Collate: unpooled paired
62
+ # -----------------------------
63
+ def collate_pair_unpooled(batch):
64
+ B = len(batch)
65
+ Ht = len(batch[0]["target_embedding"][0])
66
+ Hb = len(batch[0]["binder_embedding"][0])
67
+ Lt_max = max(int(x["target_length"]) for x in batch)
68
+ Lb_max = max(int(x["binder_length"]) for x in batch)
69
+
70
+ Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
71
+ Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
72
+ Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
73
+ Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
74
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
75
+
76
+ for i, x in enumerate(batch):
77
+ t = torch.tensor(x["target_embedding"], dtype=torch.float32)
78
+ b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
79
+ lt, lb = t.shape[0], b.shape[0]
80
+ Pt[i, :lt] = t
81
+ Pb[i, :lb] = b
82
+ Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
83
+ Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
84
+
85
+ return Pt, Mt, Pb, Mb, y
86
+
87
+
88
+ # -----------------------------
89
+ # Cross-attention models
90
+ # -----------------------------
91
+ class CrossAttnPooled(nn.Module):
92
+ """
93
+ pooled vectors -> treat as single-token sequences for cross attention
94
+ """
95
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
96
+ super().__init__()
97
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
98
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
99
+
100
+ self.layers = nn.ModuleList([])
101
+ for _ in range(n_layers):
102
+ self.layers.append(nn.ModuleDict({
103
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
104
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
105
+ "n1t": nn.LayerNorm(hidden),
106
+ "n2t": nn.LayerNorm(hidden),
107
+ "n1b": nn.LayerNorm(hidden),
108
+ "n2b": nn.LayerNorm(hidden),
109
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
110
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
111
+ }))
112
+
113
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
114
+ self.reg = nn.Linear(hidden, 1)
115
+ self.cls = nn.Linear(hidden, 3)
116
+
117
+ def forward(self, t_vec, b_vec):
118
+ # (B,Ht),(B,Hb)
119
+ t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
120
+ b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
121
+
122
+ for L in self.layers:
123
+ t_attn, _ = L["attn_tb"](t, b, b)
124
+ t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
125
+ t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
126
+
127
+ b_attn, _ = L["attn_bt"](b, t, t)
128
+ b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
129
+ b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
130
+
131
+ t0 = t[0]
132
+ b0 = b[0]
133
+ z = torch.cat([t0, b0], dim=-1)
134
+ h = self.shared(z)
135
+ return self.reg(h).squeeze(-1), self.cls(h)
136
+
137
+
138
+ class CrossAttnUnpooled(nn.Module):
139
+ """
140
+ token sequences with masks; alternating cross attention.
141
+ """
142
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
143
+ super().__init__()
144
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
145
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
146
+
147
+ self.layers = nn.ModuleList([])
148
+ for _ in range(n_layers):
149
+ self.layers.append(nn.ModuleDict({
150
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
151
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
152
+ "n1t": nn.LayerNorm(hidden),
153
+ "n2t": nn.LayerNorm(hidden),
154
+ "n1b": nn.LayerNorm(hidden),
155
+ "n2b": nn.LayerNorm(hidden),
156
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
157
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
158
+ }))
159
+
160
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
161
+ self.reg = nn.Linear(hidden, 1)
162
+ self.cls = nn.Linear(hidden, 3)
163
+
164
+ def masked_mean(self, X, M):
165
+ Mf = M.unsqueeze(-1).float()
166
+ denom = Mf.sum(dim=1).clamp(min=1.0)
167
+ return (X * Mf).sum(dim=1) / denom
168
+
169
+ def forward(self, T, Mt, B, Mb):
170
+ # T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
171
+ T = self.t_proj(T)
172
+ Bx = self.b_proj(B)
173
+
174
+ kp_t = ~Mt # key_padding_mask True = pad
175
+ kp_b = ~Mb
176
+
177
+ for L in self.layers:
178
+ # T attends to B
179
+ T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
180
+ T = L["n1t"](T + T_attn)
181
+ T = L["n2t"](T + L["fft"](T))
182
+
183
+ # B attends to T
184
+ B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
185
+ Bx = L["n1b"](Bx + B_attn)
186
+ Bx = L["n2b"](Bx + L["ffb"](Bx))
187
+
188
+ t_pool = self.masked_mean(T, Mt)
189
+ b_pool = self.masked_mean(Bx, Mb)
190
+ z = torch.cat([t_pool, b_pool], dim=-1)
191
+ h = self.shared(z)
192
+ return self.reg(h).squeeze(-1), self.cls(h)
193
+
194
+
195
+ # -----------------------------
196
+ # Train/eval
197
+ # -----------------------------
198
+ @torch.no_grad()
199
+ def eval_spearman_pooled(model, loader):
200
+ model.eval()
201
+ ys, ps = [], []
202
+ for t, b, y in loader:
203
+ t = t.to(DEVICE, non_blocking=True)
204
+ b = b.to(DEVICE, non_blocking=True)
205
+ pred, _ = model(t, b)
206
+ ys.append(y.numpy())
207
+ ps.append(pred.detach().cpu().numpy())
208
+ return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
209
+
210
+ @torch.no_grad()
211
+ def eval_spearman_unpooled(model, loader):
212
+ model.eval()
213
+ ys, ps = [], []
214
+ for T, Mt, B, Mb, y in loader:
215
+ T = T.to(DEVICE, non_blocking=True)
216
+ Mt = Mt.to(DEVICE, non_blocking=True)
217
+ B = B.to(DEVICE, non_blocking=True)
218
+ Mb = Mb.to(DEVICE, non_blocking=True)
219
+ pred, _ = model(T, Mt, B, Mb)
220
+ ys.append(y.numpy())
221
+ ps.append(pred.detach().cpu().numpy())
222
+ return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
223
+
224
+ def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
225
+ model.train()
226
+ for t, b, y in loader:
227
+ t = t.to(DEVICE, non_blocking=True)
228
+ b = b.to(DEVICE, non_blocking=True)
229
+ y = y.to(DEVICE, non_blocking=True)
230
+ y_cls = affinity_to_class_tensor(y)
231
+
232
+ opt.zero_grad(set_to_none=True)
233
+ pred, logits = model(t, b)
234
+ L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
235
+ L.backward()
236
+ if clip is not None:
237
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
238
+ opt.step()
239
+
240
+ def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
241
+ model.train()
242
+ for T, Mt, B, Mb, y in loader:
243
+ T = T.to(DEVICE, non_blocking=True)
244
+ Mt = Mt.to(DEVICE, non_blocking=True)
245
+ B = B.to(DEVICE, non_blocking=True)
246
+ Mb = Mb.to(DEVICE, non_blocking=True)
247
+ y = y.to(DEVICE, non_blocking=True)
248
+ y_cls = affinity_to_class_tensor(y)
249
+
250
+ opt.zero_grad(set_to_none=True)
251
+ pred, logits = model(T, Mt, B, Mb)
252
+ L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
253
+ L.backward()
254
+ if clip is not None:
255
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
256
+ opt.step()
257
+
258
+
259
+ # -----------------------------
260
+ # Optuna objective
261
+ # -----------------------------
262
+ def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
263
+ lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
264
+ wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
265
+ dropout = trial.suggest_float("dropout", 0.0, 0.4)
266
+ hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
267
+ n_heads = trial.suggest_categorical("n_heads", [4, 8])
268
+ n_layers = trial.suggest_int("n_layers", 1, 4)
269
+ cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
270
+ batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
271
+
272
+ # infer dims from first row
273
+ if mode == "pooled":
274
+ Ht = len(train_ds[0]["target_embedding"])
275
+ Hb = len(train_ds[0]["binder_embedding"])
276
+ collate = collate_pair_pooled
277
+ model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
278
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
279
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
280
+ eval_fn = eval_spearman_pooled
281
+ train_fn = train_one_epoch_pooled
282
+
283
+ else:
284
+ Ht = len(train_ds[0]["target_embedding"][0])
285
+ Hb = len(train_ds[0]["binder_embedding"][0])
286
+ collate = collate_pair_unpooled
287
+ model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
288
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
289
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
290
+ eval_fn = eval_spearman_unpooled
291
+ train_fn = train_one_epoch_unpooled
292
+
293
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
294
+ loss_reg = nn.MSELoss()
295
+ loss_cls = nn.CrossEntropyLoss()
296
+
297
+ best = -1e9
298
+ bad = 0
299
+ patience = 10
300
+
301
+ for ep in range(1, 61):
302
+ train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
303
+ rho = eval_fn(model, val_loader)
304
+
305
+ trial.report(rho, ep)
306
+ if trial.should_prune():
307
+ raise optuna.TrialPruned()
308
+
309
+ if rho > best + 1e-6:
310
+ best = rho
311
+ bad = 0
312
+ else:
313
+ bad += 1
314
+ if bad >= patience:
315
+ break
316
+
317
+ return float(best)
318
+
319
+
320
+ # -----------------------------
321
+ # Run: optuna + refit best
322
+ # -----------------------------
323
+ def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
324
+ out_dir = Path(out_dir)
325
+ out_dir.mkdir(parents=True, exist_ok=True)
326
+
327
+ train_ds, val_ds = load_split_paired(dataset_path)
328
+ print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")
329
+
330
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
331
+ study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)
332
+
333
+ study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
334
+ best = study.best_trial
335
+ best_params = dict(best.params)
336
+
337
+ # refit longer
338
+ lr = float(best_params["lr"])
339
+ wd = float(best_params["weight_decay"])
340
+ dropout = float(best_params["dropout"])
341
+ hidden = int(best_params["hidden_dim"])
342
+ n_heads = int(best_params["n_heads"])
343
+ n_layers = int(best_params["n_layers"])
344
+ cls_w = float(best_params["cls_weight"])
345
+ batch = int(best_params["batch_size"])
346
+
347
+ loss_reg = nn.MSELoss()
348
+ loss_cls = nn.CrossEntropyLoss()
349
+
350
+ if mode == "pooled":
351
+ Ht = len(train_ds[0]["target_embedding"])
352
+ Hb = len(train_ds[0]["binder_embedding"])
353
+ model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
354
+ collate = collate_pair_pooled
355
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
356
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
357
+ eval_fn = eval_spearman_pooled
358
+ train_fn = train_one_epoch_pooled
359
+ else:
360
+ Ht = len(train_ds[0]["target_embedding"][0])
361
+ Hb = len(train_ds[0]["binder_embedding"][0])
362
+ model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
363
+ collate = collate_pair_unpooled
364
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
365
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
366
+ eval_fn = eval_spearman_unpooled
367
+ train_fn = train_one_epoch_unpooled
368
+
369
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
370
+
371
+ best_rho = -1e9
372
+ bad = 0
373
+ patience = 20
374
+ best_state = None
375
+
376
+ for ep in range(1, 201):
377
+ train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
378
+ rho = eval_fn(model, val_loader)
379
+
380
+ if rho > best_rho + 1e-6:
381
+ best_rho = rho
382
+ bad = 0
383
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
384
+ else:
385
+ bad += 1
386
+ if bad >= patience:
387
+ break
388
+
389
+ if best_state is not None:
390
+ model.load_state_dict(best_state)
391
+
392
+ # save
393
+ torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
394
+ with open(out_dir / "best_params.json", "w") as f:
395
+ json.dump(best_params, f, indent=2)
396
+
397
+ print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
398
+
399
+
400
+ if __name__ == "__main__":
401
+ import argparse
402
+ ap = argparse.ArgumentParser()
403
+ ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
404
+ ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
405
+ ap.add_argument("--out_dir", type=str, required=True)
406
+ ap.add_argument("--n_trials", type=int, default=50)
407
+ args = ap.parse_args()
408
+
409
+ run(
410
+ dataset_path=args.dataset_path,
411
+ out_dir=args.out_dir,
412
+ mode=args.mode,
413
+ n_trials=args.n_trials,
414
+ )
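
A quick worked example of the class mapping implemented by affinity_to_class_tensor above (0=High, 1=Moderate, 2=Low); the input values are illustrative only:

import torch

y = torch.tensor([9.4, 8.0, 6.5, 9.0, 7.0])
high = y >= 9.0
low = y < 7.0
cls = torch.zeros_like(y, dtype=torch.long)
cls[~(high | low)] = 1   # 7 <= y < 9 -> Moderate
cls[low] = 2             # y < 7      -> Low
print(cls.tolist())      # [0, 1, 2, 0, 1]
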
training_classifiers/.ipynb_checkpoints/binding_wt-checkpoint.bash ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=b-data
3
+ #SBATCH --partition=dgx-b200
4
+ #SBATCH --gpus=1
5
+ #SBATCH --cpus-per-task=10
6
+ #SBATCH --mem=100G
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --output=%x_%j.out
9
+
10
+ HOME_LOC=/vast/projects/pranam/lab/yz927
11
+ SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
12
+ DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
13
+ OBJECTIVE='binding_affinity'
14
+ WT='smiles' #wt/smiles
15
+ STATUS='pooled' #pooled/unpooled
16
+ DATA_FILE="pair_wt_${WT}_${STATUS}"
17
+ LOG_LOC=$SCRIPT_LOC
18
+ DATE=$(date +%m_%d)
19
+ SPECIAL_PREFIX="binding_affinity_data_generation"
20
+
21
+ # Create log directory if it doesn't exist
22
+ mkdir -p $LOG_LOC
23
+
24
+ cd $SCRIPT_LOC
25
+ source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
26
+ conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal
27
+
28
+ python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
29
+
30
+ echo "Script completed at $(date)"
31
+ conda deactivate
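
The WT/STATUS knobs above select one of the four paired datasets written by binding_affinity_split.py; a small sketch of how the corresponding binding_training.py runs could be enumerated (output directory names are placeholders):

import itertools

for wt, status in itertools.product(["wt", "smiles"], ["pooled", "unpooled"]):
    data = f"pair_wt_{wt}_{status}"      # dataset name saved by binding_affinity_split.py
    out = f"wt_{wt}_{status}"            # placeholder out_dir
    print(f"python binding_training.py --dataset_path {data} --mode {status} --out_dir {out}")
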
training_classifiers/.ipynb_checkpoints/finetune_boost-checkpoint.py ADDED
@@ -0,0 +1,508 @@
1
+ #!/usr/bin/env python3
2
+ # finetune_xgb_halflife_cv_optuna.py
3
+
4
+ import os
5
+ import json
6
+ import math
7
+ import hashlib
8
+ from dataclasses import dataclass
9
+ from typing import Dict, Any, Optional, Tuple, List
10
+
11
+ import numpy as np
12
+ import pandas as pd
13
+ import optuna
14
+
15
+ from sklearn.model_selection import KFold
16
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
17
+ from scipy.stats import spearmanr
18
+
19
+ import torch
20
+ from transformers import AutoTokenizer, AutoModel
21
+
22
+ import xgboost as xgb
23
+
24
+
25
+ # -----------------------------
26
+ # Repro
27
+ # -----------------------------
28
+ SEED = 1986
29
+ np.random.seed(SEED)
30
+ torch.manual_seed(SEED)
31
+
32
+
33
+ # -----------------------------
34
+ # Metrics (mirrors your stability script style)
35
+ # -----------------------------
36
+ def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
37
+ rho = spearmanr(y_true, y_pred).correlation
38
+ if rho is None or np.isnan(rho):
39
+ return 0.0
40
+ return float(rho)
41
+
42
+ def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
43
+ rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
44
+ mae = float(mean_absolute_error(y_true, y_pred))
45
+ r2 = float(r2_score(y_true, y_pred))
46
+ rho = float(safe_spearmanr(y_true, y_pred))
47
+ return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
48
+
49
+
50
+ # -----------------------------
51
+ # ESM-2 embeddings (cached)
52
+ # -----------------------------
53
+ @dataclass
54
+ class ESMEmbedderConfig:
55
+ model_name: str = "facebook/esm2_t33_650M_UR50D"
56
+ batch_size: int = 8
57
+ max_length: int = 1024 # truncate very long proteins
58
+ fp16: bool = True
59
+
60
+ class ESM2Embedder:
61
+ """
62
+ Mean-pooled last hidden state (excluding special tokens) -> (H,) per sequence.
63
+ """
64
+ def __init__(self, cfg: ESMEmbedderConfig, device: str = "cuda"):
65
+ self.cfg = cfg
66
+ self.device = device if (device == "cuda" and torch.cuda.is_available()) else "cpu"
67
+ self.tokenizer = AutoTokenizer.from_pretrained(cfg.model_name, do_lower_case=False)
68
+ self.model = AutoModel.from_pretrained(cfg.model_name)
69
+ self.model.eval()
70
+ self.model.to(self.device)
71
+
72
+ # Turn off gradients
73
+ for p in self.model.parameters():
74
+ p.requires_grad = False
75
+
76
+ @torch.inference_mode()
77
+ def embed(self, seqs: List[str]) -> np.ndarray:
78
+ out = []
79
+ bs = self.cfg.batch_size
80
+
81
+ use_amp = (self.cfg.fp16 and self.device == "cuda")
82
+ autocast = torch.cuda.amp.autocast if use_amp else torch.cpu.amp.autocast # safe fallback
83
+
84
+ for i in range(0, len(seqs), bs):
85
+ batch = [s.strip().upper() for s in seqs[i:i+bs]]
86
+ toks = self.tokenizer(
87
+ batch,
88
+ return_tensors="pt",
89
+ padding=True,
90
+ truncation=True,
91
+ max_length=self.cfg.max_length,
92
+ add_special_tokens=True,
93
+ )
94
+ toks = {k: v.to(self.device) for k, v in toks.items()}
95
+ attn = toks["attention_mask"] # (B, L)
96
+
97
+ with autocast(enabled=use_amp):
98
+ h = self.model(**toks).last_hidden_state # (B, L, H)
99
+
100
+ # mask out special tokens: first token is <cls>; last non-pad token is usually <eos>
101
+ mask = attn.clone()
102
+ mask[:, 0] = 0
103
+ lengths = attn.sum(dim=1) # includes special tokens
104
+ # zero out last real token position per sequence
105
+ eos_pos = (lengths - 1).clamp(min=0)
106
+ mask[torch.arange(mask.size(0), device=mask.device), eos_pos] = 0
107
+
108
+ denom = mask.sum(dim=1).clamp(min=1).unsqueeze(-1) # (B,1)
109
+ pooled = (h * mask.unsqueeze(-1)).sum(dim=1) / denom # (B,H)
110
+ out.append(pooled.float().detach().cpu().numpy())
111
+
112
+ return np.concatenate(out, axis=0).astype(np.float32)
113
+
114
+
115
+ def dataset_fingerprint(seqs: List[str], y: np.ndarray, extra: str = "") -> str:
116
+ h = hashlib.sha256()
117
+ for s in seqs:
118
+ h.update(s.encode("utf-8"))
119
+ h.update(b"\n")
120
+ h.update(np.asarray(y, dtype=np.float32).tobytes())
121
+ h.update(extra.encode("utf-8"))
122
+ return h.hexdigest()[:16]
123
+
124
+
125
+ def load_or_compute_embeddings(
126
+ df: pd.DataFrame,
127
+ out_dir: str,
128
+ embed_cfg: ESMEmbedderConfig,
129
+ device: str,
130
+ ) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
131
+ os.makedirs(out_dir, exist_ok=True)
132
+
133
+ seqs = df["sequence"].astype(str).tolist()
134
+ y = df["half_life_hours"].astype(float).to_numpy(dtype=np.float32)
135
+
136
+ fp = dataset_fingerprint(seqs, y, extra=f"{embed_cfg.model_name}|{embed_cfg.max_length}")
137
+ emb_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.npy")
138
+ meta_path = os.path.join(out_dir, f"esm2_embeddings_{fp}.json")
139
+
140
+ if os.path.exists(emb_path) and os.path.exists(meta_path):
141
+ X = np.load(emb_path).astype(np.float32)
142
+ return X, y, np.asarray(seqs)
143
+
144
+ embedder = ESM2Embedder(embed_cfg, device=device)
145
+ X = embedder.embed(seqs) # (N,H)
146
+
147
+ np.save(emb_path, X)
148
+ with open(meta_path, "w") as f:
149
+ json.dump(
150
+ {
151
+ "fingerprint": fp,
152
+ "model_name": embed_cfg.model_name,
153
+ "max_length": embed_cfg.max_length,
154
+ "n": len(seqs),
155
+ "dim": int(X.shape[1]),
156
+ },
157
+ f,
158
+ indent=2,
159
+ )
160
+ return X, y, np.asarray(seqs)
161
+
162
+
163
+ # -----------------------------
164
+ # XGBoost training (supports "finetune" via xgb_model)
165
+ # -----------------------------
166
+ def train_xgb_reg(
167
+ X_train: np.ndarray,
168
+ y_train: np.ndarray,
169
+ X_val: np.ndarray,
170
+ y_val: np.ndarray,
171
+ params: Dict[str, Any],
172
+ base_model_json: Optional[str] = None,
173
+ ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray, int]:
174
+ dtrain = xgb.DMatrix(X_train, label=y_train)
175
+ dval = xgb.DMatrix(X_val, label=y_val)
176
+
177
+ num_boost_round = int(params.pop("num_boost_round"))
178
+ early_stopping_rounds = int(params.pop("early_stopping_rounds"))
179
+
180
+ # Important: load a fresh base model each fold (avoid leakage)
181
+ xgb_model = None
182
+ if base_model_json is not None:
183
+ booster0 = xgb.Booster()
184
+ booster0.load_model(base_model_json)
185
+ xgb_model = booster0
186
+
187
+ booster = xgb.train(
188
+ params=params,
189
+ dtrain=dtrain,
190
+ num_boost_round=num_boost_round,
191
+ evals=[(dval, "val")],
192
+ early_stopping_rounds=early_stopping_rounds,
193
+ verbose_eval=False,
194
+ xgb_model=xgb_model, # <-- "finetune": continue boosting from base model
195
+ )
196
+
197
+ p_train = booster.predict(dtrain)
198
+ p_val = booster.predict(dval)
199
+ best_iter = int(getattr(booster, "best_iteration", num_boost_round - 1))
200
+ return booster, p_train, p_val, best_iter
201
+
202
+
203
+ # -----------------------------
204
+ # Optuna objective: 5-fold mean Spearman rho
205
+ # -----------------------------
206
+ def make_cv_objective(
207
+ X: np.ndarray,
208
+ y: np.ndarray,
209
+ n_splits: int,
210
+ device: str,
211
+ base_model_json: Optional[str],
212
+ target_transform: str,
213
+ ):
214
+ kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
215
+
216
+ # Optional target transform (sometimes helps with heavy-tailed half-life)
217
+ if target_transform == "log1p":
218
+ y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
219
+ elif target_transform == "none":
220
+ y_used = y.astype(np.float32)
221
+ else:
222
+ raise ValueError(f"Unknown target_transform: {target_transform}")
223
+
224
+ def objective(trial: optuna.Trial) -> float:
225
+ # Hyperparameter ranges patterned after the stability script
226
+ params = {
227
+ "objective": "reg:squarederror",
228
+ "eval_metric": "rmse",
229
+
230
+ "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
231
+ "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
232
+ "gamma": trial.suggest_float("gamma", 0.0, 10.0),
233
+
234
+ "max_depth": trial.suggest_int("max_depth", 2, 12),
235
+ "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 200.0, log=True),
236
+ "subsample": trial.suggest_float("subsample", 0.5, 1.0),
237
+ "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
238
+
239
+ "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.2, log=True),
240
+
241
+ "tree_method": "hist",
242
+ "device": "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu",
243
+ }
244
+ params["num_boost_round"] = trial.suggest_int("num_boost_round", 30, 1500)
245
+ params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 10, 150)
246
+
247
+ fold_metrics = []
248
+ fold_best_iters = []
249
+
250
+ for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
251
+ Xtr, ytr = X[tr_idx], y_used[tr_idx]
252
+ Xva, yva = X[va_idx], y_used[va_idx]
253
+
254
+ _, _, p_va, best_iter = train_xgb_reg(
255
+ Xtr, ytr, Xva, yva, params.copy(),
256
+ base_model_json=base_model_json,
257
+ )
258
+
259
+ m = eval_regression(yva, p_va)
260
+ fold_metrics.append(m)
261
+ fold_best_iters.append(best_iter)
262
+
263
+ mean_rho = float(np.mean([m["spearman_rho"] for m in fold_metrics]))
264
+ mean_rmse = float(np.mean([m["rmse"] for m in fold_metrics]))
265
+ mean_mae = float(np.mean([m["mae"] for m in fold_metrics]))
266
+ mean_r2 = float(np.mean([m["r2"] for m in fold_metrics]))
267
+ mean_best_iter = float(np.mean(fold_best_iters))
268
+
269
+ trial.set_user_attr("cv_spearman_rho", mean_rho)
270
+ trial.set_user_attr("cv_rmse", mean_rmse)
271
+ trial.set_user_attr("cv_mae", mean_mae)
272
+ trial.set_user_attr("cv_r2", mean_r2)
273
+ trial.set_user_attr("cv_mean_best_iter", mean_best_iter)
274
+
275
+ # maximize Spearman rho (same objective as the stability workflow)
276
+ return mean_rho
277
+
278
+ return objective
279
+
280
+
281
+ def refit_and_save(
282
+ X: np.ndarray,
283
+ y: np.ndarray,
284
+ seqs: np.ndarray,
285
+ out_dir: str,
286
+ best_params: Dict[str, Any],
287
+ n_splits: int,
288
+ device: str,
289
+ base_model_json: Optional[str],
290
+ target_transform: str,
291
+ ):
292
+ os.makedirs(out_dir, exist_ok=True)
293
+
294
+ # Transform target consistently
295
+ if target_transform == "log1p":
296
+ y_used = np.log1p(np.clip(y, a_min=0.0, a_max=None)).astype(np.float32)
297
+ else:
298
+ y_used = y.astype(np.float32)
299
+
300
+ kf = KFold(n_splits=n_splits, shuffle=True, random_state=SEED)
301
+
302
+ # 1) get OOF preds + average best_iteration
303
+ oof_pred = np.zeros_like(y_used, dtype=np.float32)
304
+ best_iters = []
305
+ fold_rows = []
306
+
307
+ for fold, (tr_idx, va_idx) in enumerate(kf.split(X), start=1):
308
+ Xtr, ytr = X[tr_idx], y_used[tr_idx]
309
+ Xva, yva = X[va_idx], y_used[va_idx]
310
+
311
+ _, _, p_va, best_iter = train_xgb_reg(
312
+ Xtr, ytr, Xva, yva, best_params.copy(),
313
+ base_model_json=base_model_json,
314
+ )
315
+ oof_pred[va_idx] = p_va.astype(np.float32)
316
+ best_iters.append(best_iter)
317
+
318
+ m = eval_regression(yva, p_va)
319
+ fold_rows.append({"fold": fold, **m, "best_iter": int(best_iter)})
320
+
321
+ fold_df = pd.DataFrame(fold_rows)
322
+ fold_df.to_csv(os.path.join(out_dir, "cv_fold_metrics.csv"), index=False)
323
+
324
+ cv_metrics = eval_regression(y_used, oof_pred)
325
+ with open(os.path.join(out_dir, "cv_oof_summary.json"), "w") as f:
326
+ json.dump(cv_metrics, f, indent=2)
327
+
328
+ oof_df = pd.DataFrame({
329
+ "sequence": seqs,
330
+ "y_true_used": y_used.astype(float),
331
+ "y_pred_oof": oof_pred.astype(float),
332
+ "residual": (y_used - oof_pred).astype(float),
333
+ })
334
+ oof_df.to_csv(os.path.join(out_dir, "cv_oof_predictions.csv"), index=False)
335
+
336
+ mean_best_iter = int(round(float(np.mean(best_iters))))
337
+ final_rounds = max(mean_best_iter + 1, 10)
338
+
339
+ # 2) train final model on ALL data (no early stopping here; use final_rounds)
340
+ dtrain_all = xgb.DMatrix(X, label=y_used)
341
+
342
+ xgb_model = None
343
+ if base_model_json is not None:
344
+ booster0 = xgb.Booster()
345
+ booster0.load_model(base_model_json)
346
+ xgb_model = booster0
347
+
348
+ final_params = best_params.copy()
349
+ final_params.pop("early_stopping_rounds", None)
+ final_params.pop("num_boost_round", None) # refit uses final_rounds (CV-averaged best iteration) instead
350
+ final_params["device"] = "cuda" if (device == "cuda" and torch.cuda.is_available()) else "cpu"
351
+
352
+ booster = xgb.train(
353
+ params=final_params,
354
+ dtrain=dtrain_all,
355
+ num_boost_round=final_rounds, # CV-averaged best iteration computed above
356
+ evals=[],
357
+ verbose_eval=False,
358
+ xgb_model=xgb_model,
359
+ )
360
+
361
+ model_path = os.path.join(out_dir, "best_model_finetuned.json")
362
+ booster.save_model(model_path)
363
+
364
+ with open(os.path.join(out_dir, "final_training_notes.json"), "w") as f:
365
+ json.dump(
366
+ {
367
+ "target_transform": target_transform,
368
+ "final_rounds_used": int(final_rounds),
369
+ "cv_oof_metrics_on_used_target": cv_metrics,
370
+ "model_path": model_path,
371
+ },
372
+ f,
373
+ indent=2,
374
+ )
375
+
376
+ print("=" * 72)
377
+ print("[Final] CV OOF metrics (on transformed target if enabled):")
378
+ print(json.dumps(cv_metrics, indent=2))
379
+ print(f"[Final] Saved finetuned model -> {model_path}")
380
+ print("=" * 72)
381
+
382
+
383
+ def main():
384
+ import argparse
385
+
386
+ parser = argparse.ArgumentParser()
387
+ parser.add_argument("--csv_path", type=str, default="/scratch/pranamlab/tong/data/halflife/wt_halflife_merged_dedup.csv")
388
+ parser.add_argument("--out_dir", type=str, default="/scratch/pranamlab/tong/PeptiVerse/src/halflife/finetune_stability_xgb")
389
+
390
+ # If provided, we will "finetune" by continuing boosting from this model
391
+ parser.add_argument("--base_model_json", type=str, default='/scratch/pranamlab/tong/PeptiVerse/src/stability/xgboost/best_model.json', help="Path to an existing XGBoost .json model to continue training from")
392
+
393
+ # ESM embedding config
394
+ parser.add_argument("--esm_model", type=str, default="facebook/esm2_t33_650M_UR50D")
395
+ parser.add_argument("--esm_batch_size", type=int, default=8)
396
+ parser.add_argument("--esm_max_length", type=int, default=1024)
397
+ parser.add_argument("--no_fp16", action="store_true")
398
+
399
+ # Training config
400
+ parser.add_argument("--n_trials", type=int, default=200)
401
+ parser.add_argument("--n_splits", type=int, default=5)
402
+ parser.add_argument("--device", type=str, default="cuda", choices=["cuda", "cpu"])
403
+ parser.add_argument("--target_transform", type=str, default="none", choices=["none", "log1p"])
404
+
405
+ args = parser.parse_args()
406
+ os.makedirs(args.out_dir, exist_ok=True)
407
+
408
+ # Load data
409
+ df = pd.read_csv(args.csv_path)
410
+ if "sequence" not in df.columns or "half_life_hours" not in df.columns:
411
+ raise ValueError("CSV must contain columns: sequence, half_life_hours")
412
+
413
+ df = df.dropna(subset=["sequence", "half_life_hours"]).copy()
414
+ df["sequence"] = df["sequence"].astype(str).str.strip()
415
+ df = df[df["sequence"].str.len() > 0]
416
+ df = df.drop_duplicates(subset=["sequence"], keep="first").reset_index(drop=True)
417
+
418
+ print(f"[Data] N={len(df)} from {args.csv_path}")
419
+
420
+ # Embeddings (cached)
421
+ embed_cfg = ESMEmbedderConfig(
422
+ model_name=args.esm_model,
423
+ batch_size=args.esm_batch_size,
424
+ max_length=args.esm_max_length,
425
+ fp16=(not args.no_fp16),
426
+ )
427
+ X, y, seqs = load_or_compute_embeddings(df, args.out_dir, embed_cfg, device=args.device)
428
+ print(f"[Embeddings] X={X.shape} (float32)")
429
+
430
+ # Optuna study
431
+ sampler = optuna.samplers.TPESampler(seed=SEED)
432
+ study = optuna.create_study(
433
+ direction="maximize", # maximize mean CV Spearman rho, as in the stability script
434
+ sampler=sampler,
435
+ pruner=optuna.pruners.MedianPruner(),
436
+ )
437
+
438
+ objective = make_cv_objective(
439
+ X=X,
440
+ y=y,
441
+ n_splits=args.n_splits,
442
+ device=args.device,
443
+ base_model_json=args.base_model_json,
444
+ target_transform=args.target_transform,
445
+ )
446
+ study.optimize(objective, n_trials=args.n_trials)
447
+
448
+ # Save trials
449
+ trials_df = study.trials_dataframe()
450
+ trials_df.to_csv(os.path.join(args.out_dir, "study_trials.csv"), index=False)
451
+
452
+ best = study.best_trial
453
+ best_params = dict(best.params)
454
+
455
+ # Build full param dict for refit
456
+ best_xgb_params = {
457
+ "objective": "reg:squarederror",
458
+ "eval_metric": "rmse",
459
+ "lambda": best_params["lambda"],
460
+ "alpha": best_params["alpha"],
461
+ "gamma": best_params["gamma"],
462
+ "max_depth": best_params["max_depth"],
463
+ "min_child_weight": best_params["min_child_weight"],
464
+ "subsample": best_params["subsample"],
465
+ "colsample_bytree": best_params["colsample_bytree"],
466
+ "learning_rate": best_params["learning_rate"],
467
+ "tree_method": "hist",
468
+ "device": "cuda" if (args.device == "cuda" and torch.cuda.is_available()) else "cpu",
469
+ "num_boost_round": best_params["num_boost_round"],
470
+ "early_stopping_rounds": best_params["early_stopping_rounds"],
471
+ }
472
+
473
+ # Summary
474
+ summary = {
475
+ "best_trial_number": int(best.number),
476
+ "best_value_cv_spearman_rho": float(best.value),
477
+ "best_user_attrs": best.user_attrs,
478
+ "best_params": best_params,
479
+ "best_xgb_params_full": best_xgb_params,
480
+ "base_model_json": args.base_model_json,
481
+ "target_transform": args.target_transform,
482
+ "esm_model": args.esm_model,
483
+ "esm_max_length": args.esm_max_length,
484
+ }
485
+ with open(os.path.join(args.out_dir, "optimization_summary.json"), "w") as f:
486
+ json.dump(summary, f, indent=2)
487
+
488
+ print("=" * 72)
489
+ print("[Optuna] Best CV Spearman rho:", float(best.value))
490
+ print("[Optuna] Best params:\n", json.dumps(best_params, indent=2))
491
+ print("=" * 72)
492
+
493
+ # Refit + save final finetuned model + OOF predictions
494
+ refit_and_save(
495
+ X=X,
496
+ y=y,
497
+ seqs=seqs,
498
+ out_dir=args.out_dir,
499
+ best_params=best_xgb_params,
500
+ n_splits=args.n_splits,
501
+ device=args.device,
502
+ base_model_json=args.base_model_json,
503
+ target_transform=args.target_transform,
504
+ )
505
+
506
+
507
+ if __name__ == "__main__":
508
+ main()
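
When --target_transform log1p is used, the saved model predicts on the transformed scale; a minimal sketch of mapping predictions back to hours (the model path and input array are placeholders):

import numpy as np
import xgboost as xgb

booster = xgb.Booster()
booster.load_model("finetune_stability_xgb/best_model_finetuned.json")  # placeholder out_dir

X_new = np.random.rand(4, 1280).astype(np.float32)   # stand-in for ESM-2 650M mean-pooled embeddings (dim 1280)
pred = booster.predict(xgb.DMatrix(X_new))            # on the log1p(half_life_hours) scale under this transform
half_life_hours = np.expm1(pred)                      # inverse of the np.log1p applied during training
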
training_classifiers/.ipynb_checkpoints/generate_binding_val-checkpoint.py ADDED
@@ -0,0 +1,309 @@
1
+ #!/usr/bin/env python3
2
+ # export_val_preds_csv.py
3
+
4
+ import argparse
5
+ from pathlib import Path
6
+ import numpy as np
7
+ import torch
8
+ import torch.nn as nn
9
+ from torch.utils.data import DataLoader
10
+ from datasets import load_from_disk, DatasetDict
11
+
12
+ # -----------------------------
13
+ # Repro / device
14
+ # -----------------------------
15
+ def seed_all(seed=1986):
16
+ import random
17
+ random.seed(seed)
18
+ np.random.seed(seed)
19
+ torch.manual_seed(seed)
20
+ torch.cuda.manual_seed_all(seed)
21
+
22
+ seed_all(1986)
23
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
24
+
25
+
26
+ # -----------------------------
27
+ # Load paired DatasetDict
28
+ # -----------------------------
29
+ def load_split_paired(path: str):
30
+ dd = load_from_disk(path)
31
+ if not isinstance(dd, DatasetDict):
32
+ raise ValueError(f"Expected DatasetDict at {path}")
33
+ if "train" not in dd or "val" not in dd:
34
+ raise ValueError(f"DatasetDict missing train/val at {path}")
35
+ return dd["train"], dd["val"]
36
+
37
+
38
+ # -----------------------------
39
+ # Collate fns (same as yours)
40
+ # -----------------------------
41
+ def collate_pair_pooled(batch):
42
+ Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32)
43
+ Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32)
44
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
45
+ return Pt, Pb, y
46
+
47
+ def collate_pair_unpooled(batch):
48
+ B = len(batch)
49
+ Ht = len(batch[0]["target_embedding"][0])
50
+ Hb = len(batch[0]["binder_embedding"][0])
51
+ Lt_max = max(int(x["target_length"]) for x in batch)
52
+ Lb_max = max(int(x["binder_length"]) for x in batch)
53
+
54
+ Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
55
+ Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
56
+ Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
57
+ Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
58
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
59
+
60
+ for i, x in enumerate(batch):
61
+ t = torch.tensor(x["target_embedding"], dtype=torch.float32)
62
+ b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
63
+ lt, lb = t.shape[0], b.shape[0]
64
+ Pt[i, :lt] = t
65
+ Pb[i, :lb] = b
66
+ Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
67
+ Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
68
+
69
+ return Pt, Mt, Pb, Mb, y
70
+
71
+
72
+ # -----------------------------
73
+ # Models (same as yours)
74
+ # -----------------------------
75
+ class CrossAttnPooled(nn.Module):
76
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
77
+ super().__init__()
78
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
79
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
80
+
81
+ self.layers = nn.ModuleList([])
82
+ for _ in range(n_layers):
83
+ self.layers.append(nn.ModuleDict({
84
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
85
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
86
+ "n1t": nn.LayerNorm(hidden),
87
+ "n2t": nn.LayerNorm(hidden),
88
+ "n1b": nn.LayerNorm(hidden),
89
+ "n2b": nn.LayerNorm(hidden),
90
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
91
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
92
+ }))
93
+
94
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
95
+ self.reg = nn.Linear(hidden, 1)
96
+ self.cls = nn.Linear(hidden, 3)
97
+
98
+ def forward(self, t_vec, b_vec):
99
+ t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
100
+ b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
101
+
102
+ for L in self.layers:
103
+ t_attn, _ = L["attn_tb"](t, b, b)
104
+ t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
105
+ t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
106
+
107
+ b_attn, _ = L["attn_bt"](b, t, t)
108
+ b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
109
+ b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
110
+
111
+ z = torch.cat([t[0], b[0]], dim=-1)
112
+ h = self.shared(z)
113
+ return self.reg(h).squeeze(-1), self.cls(h)
114
+
115
+
116
+ class CrossAttnUnpooled(nn.Module):
117
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
118
+ super().__init__()
119
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
120
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
121
+
122
+ self.layers = nn.ModuleList([])
123
+ for _ in range(n_layers):
124
+ self.layers.append(nn.ModuleDict({
125
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
126
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
127
+ "n1t": nn.LayerNorm(hidden),
128
+ "n2t": nn.LayerNorm(hidden),
129
+ "n1b": nn.LayerNorm(hidden),
130
+ "n2b": nn.LayerNorm(hidden),
131
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
132
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
133
+ }))
134
+
135
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
136
+ self.reg = nn.Linear(hidden, 1)
137
+ self.cls = nn.Linear(hidden, 3)
138
+
139
+ def masked_mean(self, X, M):
140
+ Mf = M.unsqueeze(-1).float()
141
+ denom = Mf.sum(dim=1).clamp(min=1.0)
142
+ return (X * Mf).sum(dim=1) / denom
143
+
144
+ def forward(self, T, Mt, B, Mb):
145
+ T = self.t_proj(T)
146
+ Bx = self.b_proj(B)
147
+
148
+ kp_t = ~Mt
149
+ kp_b = ~Mb
150
+
151
+ for L in self.layers:
152
+ T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
153
+ T = L["n1t"](T + T_attn)
154
+ T = L["n2t"](T + L["fft"](T))
155
+
156
+ B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
157
+ Bx = L["n1b"](Bx + B_attn)
158
+ Bx = L["n2b"](Bx + L["ffb"](Bx))
159
+
160
+ t_pool = self.masked_mean(T, Mt)
161
+ b_pool = self.masked_mean(Bx, Mb)
162
+ z = torch.cat([t_pool, b_pool], dim=-1)
163
+ h = self.shared(z)
164
+ return self.reg(h).squeeze(-1), self.cls(h)
165
+
166
+
167
+ # -----------------------------
168
+ # Helpers
169
+ # -----------------------------
170
+ def softmax_np(logits: np.ndarray) -> np.ndarray:
171
+ x = logits - logits.max(axis=1, keepdims=True)
172
+ ex = np.exp(x)
173
+ return ex / ex.sum(axis=1, keepdims=True)
174
+
175
+ def expected_score_from_probs(probs: np.ndarray, class_centers=(9.5, 8.0, 6.0)) -> np.ndarray:
176
+ centers = np.asarray(class_centers, dtype=np.float32)[None, :] # (1,3)
177
+ return (probs * centers).sum(axis=1)
178
+
179
+ def load_checkpoint(ckpt_path: str, mode: str, train_ds):
180
+ ckpt = torch.load(ckpt_path, map_location="cpu")
181
+ params = ckpt.get("best_params", {})
182
+
183
+ hidden = int(params.get("hidden_dim", 512))
184
+ n_heads = int(params.get("n_heads", 8))
185
+ n_layers = int(params.get("n_layers", 3))
186
+ dropout = float(params.get("dropout", 0.1))
187
+
188
+ if mode == "pooled":
189
+ Ht = len(train_ds[0]["target_embedding"])
190
+ Hb = len(train_ds[0]["binder_embedding"])
191
+ model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
192
+ else:
193
+ Ht = len(train_ds[0]["target_embedding"][0])
194
+ Hb = len(train_ds[0]["binder_embedding"][0])
195
+ model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout)
196
+
197
+ model.load_state_dict(ckpt["state_dict"], strict=True)
198
+ model.to(DEVICE).eval()
199
+ return model
200
+
201
+
202
+ @torch.no_grad()
203
+ def export_val_preds_csv(dataset_path: str, ckpt_path: str, mode: str,
204
+ out_csv: str, batch_size: int, num_workers: int,
205
+ class_centers=(9.5, 8.0, 6.0)):
206
+ train_ds, val_ds = load_split_paired(dataset_path)
207
+ model = load_checkpoint(ckpt_path, mode, train_ds)
208
+
209
+ if mode == "pooled":
210
+ loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
211
+ num_workers=num_workers, pin_memory=True,
212
+ collate_fn=collate_pair_pooled)
213
+ y_all, pred_reg_all, logits_all = [], [], []
214
+ for t, b, y in loader:
215
+ t = t.to(DEVICE, non_blocking=True)
216
+ b = b.to(DEVICE, non_blocking=True)
217
+ pred_reg, logits = model(t, b)
218
+ y_all.append(y.numpy())
219
+ pred_reg_all.append(pred_reg.detach().cpu().numpy())
220
+ logits_all.append(logits.detach().cpu().numpy())
221
+
222
+ else:
223
+ loader = DataLoader(val_ds, batch_size=batch_size, shuffle=False,
224
+ num_workers=num_workers, pin_memory=True,
225
+ collate_fn=collate_pair_unpooled)
226
+ y_all, pred_reg_all, logits_all = [], [], []
227
+ for T, Mt, B, Mb, y in loader:
228
+ T = T.to(DEVICE, non_blocking=True)
229
+ Mt = Mt.to(DEVICE, non_blocking=True)
230
+ B = B.to(DEVICE, non_blocking=True)
231
+ Mb = Mb.to(DEVICE, non_blocking=True)
232
+ pred_reg, logits = model(T, Mt, B, Mb)
233
+ y_all.append(y.numpy())
234
+ pred_reg_all.append(pred_reg.detach().cpu().numpy())
235
+ logits_all.append(logits.detach().cpu().numpy())
236
+
237
+ y_true = np.concatenate(y_all)
238
+ y_pred_reg = np.concatenate(pred_reg_all)
239
+ logits = np.concatenate(logits_all)
240
+
241
+ probs = softmax_np(logits) # (N,3)
242
+ y_pred_cls_score = expected_score_from_probs(probs, class_centers=class_centers)
243
+
244
+ # Build CSV rows
245
+ out = Path(out_csv)
246
+ out.parent.mkdir(parents=True, exist_ok=True)
247
+
248
+ header = [
249
+ "split", "mode",
250
+ "y_true",
251
+ "y_pred_reg",
252
+ "p_high", "p_moderate", "p_low",
253
+ "y_pred_cls_score",
254
+ "center_high", "center_moderate", "center_low",
255
+ ]
256
+
257
+ centers = list(class_centers)
258
+ rows = np.column_stack([
259
+ y_true,
260
+ y_pred_reg,
261
+ probs[:, 0], probs[:, 1], probs[:, 2],
262
+ y_pred_cls_score,
263
+ np.full_like(y_true, centers[0], dtype=np.float32),
264
+ np.full_like(y_true, centers[1], dtype=np.float32),
265
+ np.full_like(y_true, centers[2], dtype=np.float32),
266
+ ])
267
+
268
+ with out.open("w") as f:
269
+ f.write(",".join(header) + "\n")
270
+ for i in range(rows.shape[0]):
271
+ f.write(
272
+ "val," + mode + "," +
273
+ ",".join(f"{rows[i, j]:.8f}" for j in range(rows.shape[1])) +
274
+ "\n"
275
+ )
276
+
277
+ print(f"[Data] Val N={len(y_true)} | mode={mode}")
278
+ print(f"[Saved] {out}")
279
+
280
+
281
+ def main():
282
+ ap = argparse.ArgumentParser()
283
+ ap.add_argument("--dataset_path", required=True, help="Paired DatasetDict path (pair_*)")
284
+ ap.add_argument("--ckpt", required=True, help="Path to best_model.pt")
285
+ ap.add_argument("--mode", choices=["pooled", "unpooled"], required=True)
286
+ ap.add_argument("--out_csv", required=True)
287
+ ap.add_argument("--batch_size", type=int, default=128)
288
+ ap.add_argument("--num_workers", type=int, default=4)
289
+
290
+ # Optional: choose class-centers for expected-score conversion
291
+ ap.add_argument("--center_high", type=float, default=9.5)
292
+ ap.add_argument("--center_moderate", type=float, default=8.0)
293
+ ap.add_argument("--center_low", type=float, default=6.0)
294
+
295
+ args = ap.parse_args()
296
+
297
+ export_val_preds_csv(
298
+ dataset_path=args.dataset_path,
299
+ ckpt_path=args.ckpt,
300
+ mode=args.mode,
301
+ out_csv=args.out_csv,
302
+ batch_size=args.batch_size,
303
+ num_workers=args.num_workers,
304
+ class_centers=(args.center_high, args.center_moderate, args.center_low),
305
+ )
306
+
307
+
308
+ if __name__ == "__main__":
309
+ main()
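
For clarity, a worked example of the expected-score conversion used in export_val_preds_csv with the default class centers; the probabilities are illustrative:

import numpy as np

probs = np.array([[0.6, 0.3, 0.1]])   # p_high, p_moderate, p_low
centers = np.array([9.5, 8.0, 6.0])   # defaults for --center_high / --center_moderate / --center_low
expected = (probs * centers).sum(axis=1)
print(expected)                        # 0.6*9.5 + 0.3*8.0 + 0.1*6.0 = 8.7
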
training_classifiers/.ipynb_checkpoints/peptiverse_filelist-checkpoint.txt ADDED
@@ -0,0 +1,234 @@
1
+ ./hemolysis/cnn_smiles/optimization_summary.txt
2
+ ./hemolysis/cnn_smiles/pr_curve.png
3
+ ./hemolysis/cnn_smiles/roc_curve.png
4
+ ./hemolysis/cnn_smiles/study_trials.csv
5
+ ./hemolysis/cnn_smiles/train_predictions.csv
6
+ ./hemolysis/cnn_smiles/val_predictions.csv
7
+ ./hemolysis/cnn_wt/optimization_summary.txt
8
+ ./hemolysis/cnn_wt/pr_curve.png
9
+ ./hemolysis/cnn_wt/roc_curve.png
10
+ ./hemolysis/cnn_wt/study_trials.csv
11
+ ./hemolysis/cnn_wt/train_predictions.csv
12
+ ./hemolysis/cnn_wt/val_predictions.csv
13
+ ./hemolysis/enet_gpu/optimization_summary.txt
14
+ ./hemolysis/enet_gpu/pr_curve.png
15
+ ./hemolysis/enet_gpu/roc_curve.png
16
+ ./hemolysis/enet_gpu/study_trials.csv
17
+ ./hemolysis/enet_gpu/train_predictions.csv
18
+ ./hemolysis/enet_gpu/val_predictions.csv
19
+ ./hemolysis/enet_gpu_smiles/optimization_summary.txt
20
+ ./hemolysis/enet_gpu_smiles/pr_curve.png
21
+ ./hemolysis/enet_gpu_smiles/roc_curve.png
22
+ ./hemolysis/enet_gpu_smiles/study_trials.csv
23
+ ./hemolysis/enet_gpu_smiles/train_predictions.csv
24
+ ./hemolysis/enet_gpu_smiles/val_predictions.csv
25
+ ./hemolysis/enet_gpu_wt/optimization_summary.txt
26
+ ./hemolysis/enet_gpu_wt/pr_curve.png
27
+ ./hemolysis/enet_gpu_wt/roc_curve.png
28
+ ./hemolysis/enet_gpu_wt/study_trials.csv
29
+ ./hemolysis/enet_gpu_wt/train_predictions.csv
30
+ ./hemolysis/enet_gpu_wt/val_predictions.csv
31
+ ./hemolysis/mlp_smiles/optimization_summary.txt
32
+ ./hemolysis/mlp_smiles/pr_curve.png
33
+ ./hemolysis/mlp_smiles/roc_curve.png
34
+ ./hemolysis/mlp_smiles/study_trials.csv
35
+ ./hemolysis/mlp_smiles/train_predictions.csv
36
+ ./hemolysis/mlp_smiles/val_predictions.csv
37
+ ./hemolysis/mlp_wt/optimization_summary.txt
38
+ ./hemolysis/mlp_wt/pr_curve.png
39
+ ./hemolysis/mlp_wt/roc_curve.png
40
+ ./hemolysis/mlp_wt/study_trials.csv
41
+ ./hemolysis/mlp_wt/train_predictions.csv
42
+ ./hemolysis/mlp_wt/val_predictions.csv
43
+ ./hemolysis/svm_gpu_wt/optimization_summary.txt
44
+ ./hemolysis/svm_gpu_wt/pr_curve.png
45
+ ./hemolysis/svm_gpu_wt/roc_curve.png
46
+ ./hemolysis/svm_gpu_wt/study_trials.csv
47
+ ./hemolysis/svm_gpu_wt/train_predictions.csv
48
+ ./hemolysis/svm_gpu_wt/val_predictions.csv
49
+ ./hemolysis/transformer_smiles/optimization_summary.txt
50
+ ./hemolysis/transformer_smiles/pr_curve.png
51
+ ./hemolysis/transformer_smiles/roc_curve.png
52
+ ./hemolysis/transformer_smiles/study_trials.csv
53
+ ./hemolysis/transformer_smiles/train_predictions.csv
54
+ ./hemolysis/transformer_smiles/val_predictions.csv
55
+ ./hemolysis/transformer_wt/optimization_summary.txt
56
+ ./hemolysis/transformer_wt/pr_curve.png
57
+ ./hemolysis/transformer_wt/roc_curve.png
58
+ ./hemolysis/transformer_wt/study_trials.csv
59
+ ./hemolysis/transformer_wt/train_predictions.csv
60
+ ./hemolysis/transformer_wt/val_predictions.csv
61
+ ./hemolysis/xgb/optimization_summary.txt
62
+ ./hemolysis/xgb/pr_curve.png
63
+ ./hemolysis/xgb/roc_curve.png
64
+ ./hemolysis/xgb/study_trials.csv
65
+ ./hemolysis/xgb/train_predictions.csv
66
+ ./hemolysis/xgb/val_predictions.csv
67
+ ./hemolysis/xgb_smiles/optimization_summary.txt
68
+ ./hemolysis/xgb_smiles/pr_curve.png
69
+ ./hemolysis/xgb_smiles/roc_curve.png
70
+ ./hemolysis/xgb_smiles/study_trials.csv
71
+ ./hemolysis/xgb_smiles/train_predictions.csv
72
+ ./hemolysis/xgb_smiles/val_predictions.csv
73
+ ./hemolysis/xgb_wt/optimization_summary.txt
74
+ ./hemolysis/xgb_wt/pr_curve.png
75
+ ./hemolysis/xgb_wt/roc_curve.png
76
+ ./hemolysis/xgb_wt/study_trials.csv
77
+ ./hemolysis/xgb_wt/train_predictions.csv
78
+ ./hemolysis/xgb_wt/val_predictions.csv
79
+ ./nf/cnn/optimization_summary.txt
80
+ ./nf/cnn/pr_curve.png
81
+ ./nf/cnn/roc_curve.png
82
+ ./nf/cnn/study_trials.csv
83
+ ./nf/cnn/train_predictions.csv
84
+ ./nf/cnn/val_predictions.csv
85
+ ./nf/cnn_wt/optimization_summary.txt
86
+ ./nf/cnn_wt/pr_curve.png
87
+ ./nf/cnn_wt/roc_curve.png
88
+ ./nf/cnn_wt/study_trials.csv
89
+ ./nf/cnn_wt/train_predictions.csv
90
+ ./nf/cnn_wt/val_predictions.csv
91
+ ./nf/enet_gpu/optimization_summary.txt
92
+ ./nf/enet_gpu/pr_curve.png
93
+ ./nf/enet_gpu/roc_curve.png
94
+ ./nf/enet_gpu/study_trials.csv
95
+ ./nf/enet_gpu/train_predictions.csv
96
+ ./nf/enet_gpu/val_predictions.csv
97
+ ./nf/enet_gpu_smiles/optimization_summary.txt
98
+ ./nf/enet_gpu_smiles/pr_curve.png
99
+ ./nf/enet_gpu_smiles/roc_curve.png
100
+ ./nf/enet_gpu_smiles/study_trials.csv
101
+ ./nf/enet_gpu_smiles/train_predictions.csv
102
+ ./nf/enet_gpu_smiles/val_predictions.csv
103
+ ./nf/enet_gpu_wt/optimization_summary.txt
104
+ ./nf/enet_gpu_wt/pr_curve.png
105
+ ./nf/enet_gpu_wt/roc_curve.png
106
+ ./nf/enet_gpu_wt/study_trials.csv
107
+ ./nf/enet_gpu_wt/train_predictions.csv
108
+ ./nf/enet_gpu_wt/val_predictions.csv
109
+ ./nf/mlp/optimization_summary.txt
110
+ ./nf/mlp/pr_curve.png
111
+ ./nf/mlp/roc_curve.png
112
+ ./nf/mlp/study_trials.csv
113
+ ./nf/mlp/train_predictions.csv
114
+ ./nf/mlp/val_predictions.csv
115
+ ./nf/mlp_wt/optimization_summary.txt
116
+ ./nf/mlp_wt/pr_curve.png
117
+ ./nf/mlp_wt/roc_curve.png
118
+ ./nf/mlp_wt/study_trials.csv
119
+ ./nf/mlp_wt/train_predictions.csv
120
+ ./nf/mlp_wt/val_predictions.csv
121
+ ./nf/svm_gpu/optimization_summary.txt
122
+ ./nf/svm_gpu/pr_curve.png
123
+ ./nf/svm_gpu/roc_curve.png
124
+ ./nf/svm_gpu/study_trials.csv
125
+ ./nf/svm_gpu/train_predictions.csv
126
+ ./nf/svm_gpu/val_predictions.csv
127
+ ./nf/svm_gpu_wt/optimization_summary.txt
128
+ ./nf/svm_gpu_wt/pr_curve.png
129
+ ./nf/svm_gpu_wt/roc_curve.png
130
+ ./nf/svm_gpu_wt/study_trials.csv
131
+ ./nf/svm_gpu_wt/train_predictions.csv
132
+ ./nf/svm_gpu_wt/val_predictions.csv
133
+ ./nf/transformer/optimization_summary.txt
134
+ ./nf/transformer/pr_curve.png
135
+ ./nf/transformer/roc_curve.png
136
+ ./nf/transformer/study_trials.csv
137
+ ./nf/transformer/train_predictions.csv
138
+ ./nf/transformer/val_predictions.csv
139
+ ./nf/transformer_wt/optimization_summary.txt
140
+ ./nf/transformer_wt/pr_curve.png
141
+ ./nf/transformer_wt/roc_curve.png
142
+ ./nf/transformer_wt/study_trials.csv
143
+ ./nf/transformer_wt/train_predictions.csv
144
+ ./nf/transformer_wt/val_predictions.csv
145
+ ./nf/xgb_wt/optimization_summary.txt
146
+ ./nf/xgb_wt/pr_curve.png
147
+ ./nf/xgb_wt/roc_curve.png
148
+ ./nf/xgb_wt/study_trials.csv
149
+ ./nf/xgb_wt/train_predictions.csv
150
+ ./nf/xgb_wt/val_predictions.csv
151
+ ./permeability_caco2/cnn_smiles/optimization_summary.txt
152
+ ./permeability_caco2/cnn_smiles/study_trials.csv
153
+ ./permeability_caco2/cnn_smiles/train_predictions.csv
154
+ ./permeability_caco2/cnn_smiles/val_predictions.csv
155
+ ./permeability_caco2/enet_gpu_smiles/optimization_summary.txt
156
+ ./permeability_caco2/enet_gpu_smiles/study_trials.csv
157
+ ./permeability_caco2/enet_gpu_smiles/train_predictions.csv
158
+ ./permeability_caco2/enet_gpu_smiles/val_predictions.csv
159
+ ./permeability_caco2/mlp_smiles/optimization_summary.txt
160
+ ./permeability_caco2/mlp_smiles/study_trials.csv
161
+ ./permeability_caco2/mlp_smiles/train_predictions.csv
162
+ ./permeability_caco2/mlp_smiles/val_predictions.csv
163
+ ./permeability_caco2/svr_smiles/optimization_summary.txt
164
+ ./permeability_caco2/svr_smiles/study_trials.csv
165
+ ./permeability_caco2/svr_smiles/train_predictions.csv
166
+ ./permeability_caco2/svr_smiles/val_predictions.csv
167
+ ./permeability_caco2/transformer_smiles/optimization_summary.txt
168
+ ./permeability_caco2/transformer_smiles/study_trials.csv
169
+ ./permeability_caco2/transformer_smiles/train_predictions.csv
170
+ ./permeability_caco2/transformer_smiles/val_predictions.csv
171
+ ./permeability_caco2/xgb_reg_smiles/optimization_summary.txt
172
+ ./permeability_caco2/xgb_reg_smiles/study_trials.csv
173
+ ./permeability_caco2/xgb_reg_smiles/train_predictions.csv
174
+ ./permeability_caco2/xgb_reg_smiles/val_predictions.csv
175
+ ./permeability_pampa/cnn_smiles/optimization_summary.txt
176
+ ./permeability_pampa/cnn_smiles/study_trials.csv
177
+ ./permeability_pampa/cnn_smiles/train_predictions.csv
178
+ ./permeability_pampa/cnn_smiles/val_predictions.csv
179
+ ./permeability_pampa/enet_gpu_smiles/optimization_summary.txt
180
+ ./permeability_pampa/enet_gpu_smiles/study_trials.csv
181
+ ./permeability_pampa/enet_gpu_smiles/train_predictions.csv
182
+ ./permeability_pampa/enet_gpu_smiles/val_predictions.csv
183
+ ./permeability_pampa/mlp_smiles/optimization_summary.txt
184
+ ./permeability_pampa/mlp_smiles/study_trials.csv
185
+ ./permeability_pampa/mlp_smiles/train_predictions.csv
186
+ ./permeability_pampa/mlp_smiles/val_predictions.csv
187
+ ./permeability_pampa/transformer_smiles/optimization_summary.txt
188
+ ./permeability_pampa/transformer_smiles/study_trials.csv
189
+ ./permeability_pampa/transformer_smiles/train_predictions.csv
190
+ ./permeability_pampa/transformer_smiles/val_predictions.csv
191
+ ./permeability_pampa/xgb_reg_smiles/optimization_summary.txt
192
+ ./permeability_pampa/xgb_reg_smiles/study_trials.csv
193
+ ./permeability_pampa/xgb_reg_smiles/train_predictions.csv
194
+ ./permeability_pampa/xgb_reg_smiles/val_predictions.csv
195
+ ./solubility/cnn_wt/optimization_summary.txt
196
+ ./solubility/cnn_wt/pr_curve.png
197
+ ./solubility/cnn_wt/roc_curve.png
198
+ ./solubility/cnn_wt/study_trials.csv
199
+ ./solubility/cnn_wt/train_predictions.csv
200
+ ./solubility/cnn_wt/val_predictions.csv
201
+ ./solubility/enet_gpu/optimization_summary.txt
202
+ ./solubility/enet_gpu/pr_curve.png
203
+ ./solubility/enet_gpu/roc_curve.png
204
+ ./solubility/enet_gpu/study_trials.csv
205
+ ./solubility/enet_gpu/train_predictions.csv
206
+ ./solubility/enet_gpu/val_predictions.csv
207
+ ./solubility/mlp_wt/optimization_summary.txt
208
+ ./solubility/mlp_wt/pr_curve.png
209
+ ./solubility/mlp_wt/roc_curve.png
210
+ ./solubility/mlp_wt/study_trials.csv
211
+ ./solubility/mlp_wt/train_predictions.csv
212
+ ./solubility/mlp_wt/val_predictions.csv
213
+ ./solubility/svm_gpu/optimization_summary.txt
214
+ ./solubility/svm_gpu/pr_curve.png
215
+ ./solubility/svm_gpu/roc_curve.png
216
+ ./solubility/svm_gpu/study_trials.csv
217
+ ./solubility/svm_gpu/train_predictions.csv
218
+ ./solubility/svm_gpu/val_predictions.csv
219
+ ./solubility/transformer_wt/optimization_summary.txt
220
+ ./solubility/transformer_wt/pr_curve.png
221
+ ./solubility/transformer_wt/roc_curve.png
222
+ ./solubility/transformer_wt/study_trials.csv
223
+ ./solubility/transformer_wt/train_predictions.csv
224
+ ./solubility/transformer_wt/val_predictions.csv
225
+ ./solubility/xgb/optimization_summary.txt
226
+ ./solubility/xgb/pr_curve.png
227
+ ./solubility/xgb/roc_curve.png
228
+ ./solubility/xgb/study_trials.csv
229
+ ./solubility/xgb/train_predictions.csv
230
+ ./solubility/xgb/val_predictions.csv
231
+ ./binding_affinity/wt_wt_pooled/optuna_trials.csv
232
+ ./binding_affinity/wt_smiles_pooled/optuna_trials.csv
233
+ ./binding_affinity/wt_smiles_unpooled/optuna_trials.csv
234
+ ./binding_affinity/wt_wt_unpooled/optuna_trials.csv
training_classifiers/.ipynb_checkpoints/train_boost-checkpoint.py ADDED
@@ -0,0 +1,417 @@
1
+ import os
2
+ import json
3
+ import joblib
4
+ import optuna
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+
9
+ from dataclasses import dataclass
10
+ from typing import Dict, Any, Tuple, Optional
11
+
12
+ from datasets import load_from_disk, DatasetDict
13
+ from sklearn.metrics import (
14
+ f1_score, roc_auc_score, average_precision_score,
15
+ precision_recall_curve, roc_curve
16
+ )
17
+ from sklearn.linear_model import LogisticRegression
18
+ from sklearn.ensemble import AdaBoostClassifier
19
+ from sklearn.tree import DecisionTreeClassifier
20
+ from linearboost import LinearBoostClassifier
21
+
22
+ import xgboost as xgb
23
+ from lightning.pytorch import seed_everything
24
+
25
+ seed_everything(1986)
26
+
27
+ # -----------------------------
28
+ # Data loading
29
+ # -----------------------------
30
+ @dataclass
31
+ class SplitData:
32
+ X_train: np.ndarray
33
+ y_train: np.ndarray
34
+ seq_train: Optional[np.ndarray]
35
+ X_val: np.ndarray
36
+ y_val: np.ndarray
37
+ seq_val: Optional[np.ndarray]
38
+
39
+
40
+ def _stack_embeddings(col) -> np.ndarray:
41
+ # HF datasets often store embeddings as list-of-floats per row
42
+ arr = np.asarray(col, dtype=np.float32)
43
+ if arr.ndim != 2:
44
+ arr = np.stack(col).astype(np.float32)
45
+ return arr
46
+
47
+
48
+ def load_split_data(dataset_path: str) -> SplitData:
49
+ ds = load_from_disk(dataset_path)
50
+
51
+ # Case A: DatasetDict with train/val
52
+ if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
53
+ train_ds, val_ds = ds["train"], ds["val"]
54
+ else:
55
+ # Case B: Single dataset with "split" column
56
+ if "split" not in ds.column_names:
57
+ raise ValueError(
58
+ "Dataset must be a DatasetDict(train/val) or have a 'split' column."
59
+ )
60
+ train_ds = ds.filter(lambda x: x["split"] == "train")
61
+ val_ds = ds.filter(lambda x: x["split"] == "val")
62
+
63
+ for required in ["embedding", "label"]:
64
+ if required not in train_ds.column_names:
65
+ raise ValueError(f"Missing column '{required}' in train split.")
66
+ if required not in val_ds.column_names:
67
+ raise ValueError(f"Missing column '{required}' in val split.")
68
+
69
+ X_train = _stack_embeddings(train_ds["embedding"])
70
+ y_train = np.asarray(train_ds["label"], dtype=np.int64)
71
+
72
+ X_val = _stack_embeddings(val_ds["embedding"])
73
+ y_val = np.asarray(val_ds["label"], dtype=np.int64)
74
+
75
+ seq_train = None
76
+ seq_val = None
77
+ if "sequence" in train_ds.column_names:
78
+ seq_train = np.asarray(train_ds["sequence"])
79
+ if "sequence" in val_ds.column_names:
80
+ seq_val = np.asarray(val_ds["sequence"])
81
+
82
+ return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
83
+
84
+
85
+ # -----------------------------
86
+ # Metrics + thresholding
87
+ # -----------------------------
88
+ def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
89
+ """
90
+ Find threshold maximizing F1 on the given set.
91
+ Returns (best_threshold, best_f1).
92
+ """
93
+ precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
94
+ # precision_recall_curve returns thresholds of length n-1
95
+ # compute F1 for those thresholds
96
+ f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
97
+ best_idx = int(np.nanargmax(f1s))
98
+ return float(thresholds[best_idx]), float(f1s[best_idx])
99
+
100
+
101
+ def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
102
+ y_pred = (y_prob >= threshold).astype(int)
103
+ return {
104
+ "f1": float(f1_score(y_true, y_pred)),
105
+ "auc": float(roc_auc_score(y_true, y_prob)),
106
+ "ap": float(average_precision_score(y_true, y_prob)),
107
+ "threshold": float(threshold),
108
+ }
109
+
110
+
111
+ # -----------------------------
112
+ # Model factories
113
+ # -----------------------------
114
+ def train_xgb(
115
+ X_train, y_train, X_val, y_val, params: Dict[str, Any]
116
+ ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
117
+ dtrain = xgb.DMatrix(X_train, label=y_train)
118
+ dval = xgb.DMatrix(X_val, label=y_val)
119
+
120
+ num_boost_round = int(params.pop("num_boost_round"))
121
+ early_stopping_rounds = int(params.pop("early_stopping_rounds"))
122
+
123
+ booster = xgb.train(
124
+ params=params,
125
+ dtrain=dtrain,
126
+ num_boost_round=num_boost_round,
127
+ evals=[(dval, "val")],
128
+ early_stopping_rounds=early_stopping_rounds,
129
+ verbose_eval=False,
130
+ )
131
+
132
+ p_train = booster.predict(dtrain)
133
+ p_val = booster.predict(dval)
134
+ return booster, p_train, p_val
135
+
136
+
137
+ def train_adaboost(
138
+ X_train, y_train, X_val, y_val, params: Dict[str, Any]
139
+ ) -> Tuple[AdaBoostClassifier, np.ndarray, np.ndarray]:
140
+ base_depth = int(params.pop("base_depth"))
141
+ clf = AdaBoostClassifier(
142
+ estimator=DecisionTreeClassifier(max_depth=base_depth),
143
+ n_estimators=int(params["n_estimators"]),
144
+ learning_rate=float(params["learning_rate"]),
145
+ algorithm="SAMME",
146
+ )
147
+ clf.fit(X_train, y_train)
148
+ p_train = clf.predict_proba(X_train)[:, 1]
149
+ p_val = clf.predict_proba(X_val)[:, 1]
150
+ return clf, p_train, p_val
151
+
152
+
153
+ def train_linearboost(X_train, y_train, X_val, y_val, params):
154
+ clf = LinearBoostClassifier(**params)
155
+ clf.fit(X_train, y_train)
156
+ p_train = clf.predict_proba(X_train)[:, 1]
157
+ p_val = clf.predict_proba(X_val)[:, 1]
158
+ return clf, p_train, p_val
159
+
160
+
161
+ def suggest_linearboost_params(trial):
162
+ # Core boosting params
163
+ params = {
164
+ "n_estimators": trial.suggest_int("n_estimators", 50, 800),
165
+ "learning_rate": trial.suggest_float("learning_rate", 0.01, 1.0, log=True),
166
+ "algorithm": trial.suggest_categorical("algorithm", ["SAMME.R", "SAMME"]),
167
+ # Scaling choices from docs (you can expand this list if you want)
168
+ "scaler": trial.suggest_categorical(
169
+ "scaler",
170
+ ["minmax", "standard", "robust", "quantile-uniform", "quantile-normal", "power"]
171
+ ),
172
+ # useful for imbalanced splits
173
+ "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
174
+ # kernel trick
175
+ "kernel": trial.suggest_categorical("kernel", ["linear", "rbf", "poly", "sigmoid"]),
176
+ }
177
+
178
+ # Kernel-specific params (only when relevant)
179
+ if params["kernel"] in ["rbf", "poly"]:
180
+ params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
181
+ else:
182
+ params["gamma"] = None # docs: default treated as 1/n_features for rbf/poly :contentReference[oaicite:5]{index=5}
183
+
184
+ if params["kernel"] == "poly":
185
+ params["degree"] = trial.suggest_int("degree", 2, 6) # docs default=3 :contentReference[oaicite:6]{index=6}
186
+ params["coef0"] = trial.suggest_float("coef0", 0.0, 5.0) # docs default=1 :contentReference[oaicite:7]{index=7}
187
+ else:
188
+ # safe defaults
189
+ params["degree"] = 3
190
+ params["coef0"] = 1.0
191
+
192
+ return params
193
+ # -----------------------------
194
+ # Saving artifacts
195
+ # -----------------------------
196
+ def save_predictions_csv(
197
+ out_dir: str,
198
+ split_name: str,
199
+ y_true: np.ndarray,
200
+ y_prob: np.ndarray,
201
+ threshold: float,
202
+ sequences: Optional[np.ndarray] = None,
203
+ ):
204
+ os.makedirs(out_dir, exist_ok=True)
205
+ df = pd.DataFrame({
206
+ "y_true": y_true.astype(int),
207
+ "y_prob": y_prob.astype(float),
208
+ "y_pred": (y_prob >= threshold).astype(int),
209
+ })
210
+ if sequences is not None:
211
+ df.insert(0, "sequence", sequences)
212
+ df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
213
+
214
+
215
+ def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
216
+ os.makedirs(out_dir, exist_ok=True)
217
+
218
+ # PR
219
+ precision, recall, _ = precision_recall_curve(y_true, y_prob)
220
+ plt.figure()
221
+ plt.plot(recall, precision)
222
+ plt.xlabel("Recall")
223
+ plt.ylabel("Precision")
224
+ plt.title("Precision-Recall Curve")
225
+ plt.tight_layout()
226
+ plt.savefig(os.path.join(out_dir, "pr_curve.png"))
227
+ plt.close()
228
+
229
+ # ROC
230
+ fpr, tpr, _ = roc_curve(y_true, y_prob)
231
+ plt.figure()
232
+ plt.plot(fpr, tpr)
233
+ plt.xlabel("False Positive Rate")
234
+ plt.ylabel("True Positive Rate")
235
+ plt.title("ROC Curve")
236
+ plt.tight_layout()
237
+ plt.savefig(os.path.join(out_dir, "roc_curve.png"))
238
+ plt.close()
239
+
240
+
241
+ # -----------------------------
242
+ # Optuna objectives
243
+ # -----------------------------
244
+ def make_objective(model_name: str, data: SplitData, out_dir: str):
245
+ Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
246
+
247
+ def objective(trial: optuna.Trial) -> float:
248
+ if model_name == "xgb":
249
+ params = {
250
+ "objective": "binary:logistic",
251
+ "eval_metric": "logloss",
252
+ "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
253
+ "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
254
+ "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
255
+ "subsample": trial.suggest_float("subsample", 0.5, 1.0),
256
+ "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
257
+ "max_depth": trial.suggest_int("max_depth", 2, 15),
258
+ "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
259
+ "gamma": trial.suggest_float("gamma", 0.0, 10.0),
260
+ "tree_method": "hist",
261
+ "device": "cuda",
262
+ }
263
+
264
+ # Optional GPU: set env CUDA_VISIBLE_DEVICES externally if you want.
265
+ # If you *know* you want GPU and your xgboost supports it:
266
+ # params["device"] = "cuda"
267
+
268
+ params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
269
+ params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
270
+
271
+ model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
272
+
273
+ elif model_name == "adaboost":
274
+ params = {
275
+ "n_estimators": trial.suggest_int("n_estimators", 50, 800),
276
+ "learning_rate": trial.suggest_float("learning_rate", 1e-3, 2.0, log=True),
277
+ "base_depth": trial.suggest_int("base_depth", 1, 4),
278
+ }
279
+ model, p_tr, p_va = train_adaboost(Xtr, ytr, Xva, yva, params)
280
+
281
+ elif model_name == "linearboost":
282
+ params = suggest_linearboost_params(trial)
283
+ model, p_tr, p_va = train_linearboost(Xtr, ytr, Xva, yva, params)
284
+ else:
285
+ raise ValueError(f"Unknown model_name={model_name}")
286
+
287
+ # Threshold picked on val for fair comparison across models
288
+ thr, f1_at_thr = best_f1_threshold(yva, p_va)
289
+ metrics = eval_binary(yva, p_va, thr)
290
+
291
+ # Track best trial artifacts inside the study directory
292
+ trial.set_user_attr("threshold", thr)
293
+ trial.set_user_attr("auc", metrics["auc"])
294
+ trial.set_user_attr("ap", metrics["ap"])
295
+
296
+ return f1_at_thr
297
+
298
+ return objective
299
+
300
+ # -----------------------------
301
+ # Main runner
302
+ # -----------------------------
303
+ def run_optuna_and_refit(
304
+ dataset_path: str,
305
+ out_dir: str,
306
+ model_name: str,
307
+ n_trials: int = 200,
308
+ ):
309
+ os.makedirs(out_dir, exist_ok=True)
310
+
311
+ data = load_split_data(dataset_path)
312
+ print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
313
+
314
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
315
+ study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
316
+
317
+ # Save trials table
318
+ trials_df = study.trials_dataframe()
319
+ trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
320
+
321
+ best = study.best_trial
322
+ best_params = dict(best.params)
323
+ best_thr = float(best.user_attrs["threshold"])
324
+ best_auc = float(best.user_attrs["auc"])
325
+ best_ap = float(best.user_attrs["ap"])
326
+ best_f1 = float(best.value)
327
+
328
+ # Refit best model on train (same protocol as objective)
329
+ if model_name == "xgb":
330
+ # Reconstruct full param dict
331
+ params = {
332
+ "objective": "binary:logistic",
333
+ "eval_metric": "logloss",
334
+ "lambda": best_params["lambda"],
335
+ "alpha": best_params["alpha"],
336
+ "colsample_bytree": best_params["colsample_bytree"],
337
+ "subsample": best_params["subsample"],
338
+ "learning_rate": best_params["learning_rate"],
339
+ "max_depth": best_params["max_depth"],
340
+ "min_child_weight": best_params["min_child_weight"],
341
+ "gamma": best_params["gamma"],
342
+ "tree_method": "hist",
343
+ "num_boost_round": best_params["num_boost_round"],
344
+ "early_stopping_rounds": best_params["early_stopping_rounds"],
345
+ }
346
+ model, p_tr, p_va = train_xgb(
347
+ data.X_train, data.y_train, data.X_val, data.y_val, params
348
+ )
349
+ model_path = os.path.join(out_dir, "best_model.json")
350
+ model.save_model(model_path)
351
+
352
+ elif model_name == "adaboost":
353
+ params = best_params
354
+ model, p_tr, p_va = train_adaboost(
355
+ data.X_train, data.y_train, data.X_val, data.y_val, params
356
+ )
357
+ model_path = os.path.join(out_dir, "best_model.joblib")
358
+ joblib.dump(model, model_path)
359
+
360
+ elif model_name == "linearboost":
361
+ params = best_params
362
+
363
+ model, p_tr, p_va = train_linearboost(
364
+ data.X_train, data.y_train, data.X_val, data.y_val, params
365
+ )
366
+
367
+ model_path = os.path.join(out_dir, "best_model.joblib")
368
+ joblib.dump(model, model_path)
369
+ else:
370
+ raise ValueError(model_name)
371
+
372
+ # Save predictions CSVs
373
+ save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
374
+ save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
375
+
376
+ # Plots on val
377
+ plot_curves(out_dir, data.y_val, p_va)
378
+
379
+ # Summary
380
+ summary = [
381
+ "=" * 72,
382
+ f"MODEL: {model_name}",
383
+ f"Best trial: {best.number}",
384
+ f"Best F1 (val @ best-threshold): {best_f1:.4f}",
385
+ f"Val AUC: {best_auc:.4f}",
386
+ f"Val AP: {best_ap:.4f}",
387
+ f"Best threshold (picked on val): {best_thr:.4f}",
388
+ f"Model saved to: {model_path}",
389
+ "Best params:",
390
+ json.dumps(best_params, indent=2),
391
+ "=" * 72,
392
+ ]
393
+ with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
394
+ f.write("\n".join(summary))
395
+ print("\n".join(summary))
396
+
397
+
398
+ if __name__ == "__main__":
399
+ # Example usage:
400
+ # dataset_path = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/data/solubility"
401
+ # out_dir = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/src/solubility/xgb"
402
+ # run_optuna_and_refit(dataset_path, out_dir, model_name="xgb", n_trials=200)
403
+
404
+ import argparse
405
+ parser = argparse.ArgumentParser()
406
+ parser.add_argument("--dataset_path", type=str, required=True)
407
+ parser.add_argument("--out_dir", type=str, required=True)
408
+ parser.add_argument("--model", type=str, choices=["xgb", "adaboost", "linearboost"], required=True)
409
+ parser.add_argument("--n_trials", type=int, default=200)
410
+ args = parser.parse_args()
411
+
412
+ run_optuna_and_refit(
413
+ dataset_path=args.dataset_path,
414
+ out_dir=args.out_dir,
415
+ model_name=args.model,
416
+ n_trials=args.n_trials,
417
+ )
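For reference, a minimal self-contained sketch of the best-F1 threshold selection that best_f1_threshold implements in the script above; the toy labels and scores are made up for illustration, and only NumPy and scikit-learn are assumed.

import numpy as np
from sklearn.metrics import precision_recall_curve, f1_score

# Toy labels/scores (hypothetical values, for illustration only)
y_true = np.array([0, 0, 1, 1, 1, 0, 1, 0])
y_prob = np.array([0.10, 0.40, 0.35, 0.80, 0.70, 0.20, 0.90, 0.55])

# precision/recall have one more entry than thresholds, so drop the last point
precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
best_idx = int(np.nanargmax(f1s))
best_thr = float(thresholds[best_idx])

y_pred = (y_prob >= best_thr).astype(int)
print(f"threshold={best_thr:.3f}  F1={f1_score(y_true, y_pred):.3f}")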
training_classifiers/.ipynb_checkpoints/train_ml-checkpoint.py ADDED
@@ -0,0 +1,468 @@
1
+ import os
2
+ import json
3
+ import joblib
4
+ import optuna
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ from dataclasses import dataclass
9
+ from typing import Dict, Any, Tuple, Optional
10
+ from datasets import load_from_disk, DatasetDict
11
+ from sklearn.metrics import (
12
+ f1_score, roc_auc_score, average_precision_score,
13
+ precision_recall_curve, roc_curve
14
+ )
15
+ from sklearn.linear_model import LogisticRegression
16
+ from sklearn.svm import SVC, LinearSVC
17
+ from sklearn.calibration import CalibratedClassifierCV
18
+ import torch
19
+ import time
20
+ import xgboost as xgb
21
+ from lightning.pytorch import seed_everything
22
+ import cupy as cp
23
+ from cuml.svm import SVC as cuSVC
24
+ from cuml.linear_model import LogisticRegression as cuLogReg
25
+ seed_everything(1986)
26
+
27
+
28
+ def to_gpu(X: np.ndarray):
29
+ if isinstance(X, cp.ndarray):
30
+ return X
31
+ return cp.asarray(X, dtype=cp.float32)
32
+
33
+ def to_cpu(x):
34
+ if isinstance(x, cp.ndarray):
35
+ return cp.asnumpy(x)
36
+ return np.asarray(x)
37
+
38
+ @dataclass
39
+ class SplitData:
40
+ X_train: np.ndarray
41
+ y_train: np.ndarray
42
+ seq_train: Optional[np.ndarray]
43
+ X_val: np.ndarray
44
+ y_val: np.ndarray
45
+ seq_val: Optional[np.ndarray]
46
+
47
+
48
+ def _stack_embeddings(col) -> np.ndarray:
49
+ arr = np.asarray(col, dtype=np.float32)
50
+ if arr.ndim != 2:
51
+ arr = np.stack(col).astype(np.float32)
52
+ return arr
53
+
54
+
55
+ def load_split_data(dataset_path: str) -> SplitData:
56
+ ds = load_from_disk(dataset_path)
57
+
58
+ # Case A: DatasetDict with train/val
59
+ if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
60
+ train_ds, val_ds = ds["train"], ds["val"]
61
+ else:
62
+ # Case B: Single dataset with "split" column
63
+ if "split" not in ds.column_names:
64
+ raise ValueError(
65
+ "Dataset must be a DatasetDict(train/val) or have a 'split' column."
66
+ )
67
+ train_ds = ds.filter(lambda x: x["split"] == "train")
68
+ val_ds = ds.filter(lambda x: x["split"] == "val")
69
+
70
+ for required in ["embedding", "label"]:
71
+ if required not in train_ds.column_names:
72
+ raise ValueError(f"Missing column '{required}' in train split.")
73
+ if required not in val_ds.column_names:
74
+ raise ValueError(f"Missing column '{required}' in val split.")
75
+
76
+ X_train = _stack_embeddings(train_ds["embedding"])
77
+ y_train = np.asarray(train_ds["label"], dtype=np.int64)
78
+
79
+ X_val = _stack_embeddings(val_ds["embedding"])
80
+ y_val = np.asarray(val_ds["label"], dtype=np.int64)
81
+
82
+ seq_train = None
83
+ seq_val = None
84
+ if "sequence" in train_ds.column_names:
85
+ seq_train = np.asarray(train_ds["sequence"])
86
+ if "sequence" in val_ds.column_names:
87
+ seq_val = np.asarray(val_ds["sequence"])
88
+
89
+ return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
90
+
91
+
92
+ def best_f1_threshold(y_true: np.ndarray, y_prob: np.ndarray) -> Tuple[float, float]:
93
+ """
94
+ Find threshold maximizing F1 on the given set.
95
+ Returns (best_threshold, best_f1).
96
+ """
97
+ precision, recall, thresholds = precision_recall_curve(y_true, y_prob)
98
+ f1s = (2 * precision[:-1] * recall[:-1]) / (precision[:-1] + recall[:-1] + 1e-12)
99
+ best_idx = int(np.nanargmax(f1s))
100
+ return float(thresholds[best_idx]), float(f1s[best_idx])
101
+
102
+
103
+ def eval_binary(y_true: np.ndarray, y_prob: np.ndarray, threshold: float) -> Dict[str, float]:
104
+ y_pred = (y_prob >= threshold).astype(int)
105
+ return {
106
+ "f1": float(f1_score(y_true, y_pred)),
107
+ "auc": float(roc_auc_score(y_true, y_prob)),
108
+ "ap": float(average_precision_score(y_true, y_prob)),
109
+ "threshold": float(threshold),
110
+ }
111
+
112
+
113
+ # -----------------------------
114
+ # Model
115
+ # -----------------------------
116
+ def train_xgb(
117
+ X_train, y_train, X_val, y_val, params: Dict[str, Any]
118
+ ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
119
+ dtrain = xgb.DMatrix(X_train, label=y_train)
120
+ dval = xgb.DMatrix(X_val, label=y_val)
121
+
122
+ num_boost_round = int(params.pop("num_boost_round"))
123
+ early_stopping_rounds = int(params.pop("early_stopping_rounds"))
124
+
125
+ booster = xgb.train(
126
+ params=params,
127
+ dtrain=dtrain,
128
+ num_boost_round=num_boost_round,
129
+ evals=[(dval, "val")],
130
+ early_stopping_rounds=early_stopping_rounds,
131
+ verbose_eval=False,
132
+ )
133
+
134
+ p_train = booster.predict(dtrain)
135
+ p_val = booster.predict(dval)
136
+ return booster, p_train, p_val
137
+
138
+ def train_cuml_svc(X_train, y_train, X_val, y_val, params):
139
+ Xtr = to_gpu(X_train)
140
+ Xva = to_gpu(X_val)
141
+ ytr = to_gpu(y_train).astype(cp.int32)
142
+
143
+ clf = cuSVC(
144
+ C=float(params["C"]),
145
+ kernel=params["kernel"],
146
+ gamma=params.get("gamma", "scale"),
147
+ class_weight=params.get("class_weight", None),
148
+ probability=bool(params.get("probability", True)),
149
+ random_state=1986,
150
+ max_iter=int(params.get("max_iter", 1000)),
151
+ tol=float(params.get("tol", 1e-4)),
152
+ )
153
+
154
+ clf.fit(Xtr, ytr)
155
+
156
+ p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
157
+ p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
158
+ return clf, p_train, p_val
159
+
160
+ def train_cuml_elastic_net(X_train, y_train, X_val, y_val, params):
161
+ Xtr = to_gpu(X_train)
162
+ Xva = to_gpu(X_val)
163
+ ytr = to_gpu(y_train).astype(cp.int32)
164
+
165
+ clf = cuLogReg(
166
+ penalty="elasticnet",
167
+ C=float(params["C"]),
168
+ l1_ratio=float(params["l1_ratio"]),
169
+ class_weight=params.get("class_weight", None),
170
+ max_iter=int(params.get("max_iter", 1000)),
171
+ tol=float(params.get("tol", 1e-4)),
172
+ solver="qn",
173
+ fit_intercept=True,
174
+ )
175
+ clf.fit(Xtr, ytr)
176
+
177
+ p_train = to_cpu(clf.predict_proba(Xtr)[:, 1])
178
+ p_val = to_cpu(clf.predict_proba(Xva)[:, 1])
179
+ return clf, p_train, p_val
180
+
181
+
182
+ def train_svm(X_train, y_train, X_val, y_val, params):
183
+ """
184
+ Kernel SVM via SVC. CPU only in sklearn.
185
+ probability=True enables predict_proba but is slower.
186
+ """
187
+ clf = SVC(
188
+ C=float(params["C"]),
189
+ kernel=params["kernel"],
190
+ gamma=params.get("gamma", "scale"),
191
+ class_weight=params.get("class_weight", None),
192
+ probability=True,
193
+ random_state=1986,
194
+ )
195
+ clf.fit(X_train, y_train)
196
+ p_train = clf.predict_proba(X_train)[:, 1]
197
+ p_val = clf.predict_proba(X_val)[:, 1]
198
+ return clf, p_train, p_val
199
+
200
+
201
+ def train_linearsvm_calibrated(X_train, y_train, X_val, y_val, params):
202
+ """
203
+ Fast linear SVM (LinearSVC) + probability calibration.
204
+ Usually much faster than SVC on large datasets.
205
+ """
206
+ base = LinearSVC(
207
+ C=float(params["C"]),
208
+ class_weight=params.get("class_weight", None),
209
+ max_iter=int(params.get("max_iter", 5000)),
210
+ random_state=1986,
211
+ )
212
+ # calibration to get probabilities for PR/ROC + thresholding
213
+ clf = CalibratedClassifierCV(base, method="sigmoid", cv=3)
214
+ clf.fit(X_train, y_train)
215
+ p_train = clf.predict_proba(X_train)[:, 1]
216
+ p_val = clf.predict_proba(X_val)[:, 1]
217
+ return clf, p_train, p_val
218
+
219
+ # -----------------------------
220
+ # Saving artifacts
221
+ # -----------------------------
222
+ def save_predictions_csv(
223
+ out_dir: str,
224
+ split_name: str,
225
+ y_true: np.ndarray,
226
+ y_prob: np.ndarray,
227
+ threshold: float,
228
+ sequences: Optional[np.ndarray] = None,
229
+ ):
230
+ os.makedirs(out_dir, exist_ok=True)
231
+ df = pd.DataFrame({
232
+ "y_true": y_true.astype(int),
233
+ "y_prob": y_prob.astype(float),
234
+ "y_pred": (y_prob >= threshold).astype(int),
235
+ })
236
+ if sequences is not None:
237
+ df.insert(0, "sequence", sequences)
238
+ df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
239
+
240
+
241
+ def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
242
+ os.makedirs(out_dir, exist_ok=True)
243
+
244
+ # PR
245
+ precision, recall, _ = precision_recall_curve(y_true, y_prob)
246
+ plt.figure()
247
+ plt.plot(recall, precision)
248
+ plt.xlabel("Recall")
249
+ plt.ylabel("Precision")
250
+ plt.title("Precision-Recall Curve")
251
+ plt.tight_layout()
252
+ plt.savefig(os.path.join(out_dir, "pr_curve.png"))
253
+ plt.close()
254
+
255
+ # ROC
256
+ fpr, tpr, _ = roc_curve(y_true, y_prob)
257
+ plt.figure()
258
+ plt.plot(fpr, tpr)
259
+ plt.xlabel("False Positive Rate")
260
+ plt.ylabel("True Positive Rate")
261
+ plt.title("ROC Curve")
262
+ plt.tight_layout()
263
+ plt.savefig(os.path.join(out_dir, "roc_curve.png"))
264
+ plt.close()
265
+
266
+
267
+ # -----------------------------
268
+ # Optuna objectives
269
+ # -----------------------------
270
+ def make_objective(model_name: str, data: SplitData, out_dir: str):
271
+ Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
272
+
273
+ def objective(trial: optuna.Trial) -> float:
274
+ if model_name == "xgb":
275
+ params = {
276
+ "objective": "binary:logistic",
277
+ "eval_metric": "logloss",
278
+ "lambda": trial.suggest_float("lambda", 1e-8, 50.0, log=True),
279
+ "alpha": trial.suggest_float("alpha", 1e-8, 50.0, log=True),
280
+ "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
281
+ "subsample": trial.suggest_float("subsample", 0.5, 1.0),
282
+ "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
283
+ "max_depth": trial.suggest_int("max_depth", 2, 15),
284
+ "min_child_weight": trial.suggest_int("min_child_weight", 1, 500),
285
+ "gamma": trial.suggest_float("gamma", 0.0, 10.0),
286
+ "tree_method": "hist",
287
+ "device": "cuda",
288
+ }
289
+ params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 1500)
290
+ params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
291
+
292
+ model, p_tr, p_va = train_xgb(Xtr, ytr, Xva, yva, params.copy())
293
+
294
+ elif model_name == "svm":
295
+ svm_kind = trial.suggest_categorical("svm_kind", ["svc", "linear_calibrated"])
296
+
297
+ if svm_kind == "svc":
298
+ params = {
299
+ "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
300
+ "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
301
+ "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
302
+ }
303
+ if params["kernel"] in ["rbf", "poly", "sigmoid"]:
304
+ params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
305
+ else:
306
+ params["gamma"] = "scale"
307
+
308
+ model, p_tr, p_va = train_svm(Xtr, ytr, Xva, yva, params)
309
+
310
+ else:
311
+ params = {
312
+ "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
313
+ "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
314
+ "max_iter": trial.suggest_int("max_iter", 2000, 20000),
315
+ }
316
+ model, p_tr, p_va = train_linearsvm_calibrated(Xtr, ytr, Xva, yva, params)
317
+ elif model_name == "svm_gpu":
318
+ params = {
319
+ "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
320
+ "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
321
+ "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
322
+ "probability": True,
323
+ "max_iter": trial.suggest_int("max_iter", 200, 5000),
324
+ "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
325
+ }
326
+ if params["kernel"] in ["rbf", "poly", "sigmoid"]:
327
+ params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
328
+ else:
329
+ params["gamma"] = "scale"
330
+
331
+ model, p_tr, p_va = train_cuml_svc(Xtr, ytr, Xva, yva, params)
332
+
333
+ elif model_name == "enet_gpu":
334
+ params = {
335
+ "C": trial.suggest_float("C", 1e-4, 1e3, log=True),
336
+ "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
337
+ "class_weight": trial.suggest_categorical("class_weight", [None, "balanced"]),
338
+ "max_iter": trial.suggest_int("max_iter", 200, 5000),
339
+ "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
340
+ }
341
+ model, p_tr, p_va = train_cuml_elastic_net(Xtr, ytr, Xva, yva, params)
342
+ else:
343
+ raise ValueError(f"Unknown model_name={model_name}")
344
+
345
+ thr, f1_at_thr = best_f1_threshold(yva, p_va)
346
+ metrics = eval_binary(yva, p_va, thr)
347
+ trial.set_user_attr("threshold", thr)
348
+ trial.set_user_attr("auc", metrics["auc"])
349
+ trial.set_user_attr("ap", metrics["ap"])
350
+ return f1_at_thr
351
+
352
+ return objective
353
+
354
+ # -----------------------------
355
+ # Main
356
+ # -----------------------------
357
+ def run_optuna_and_refit(
358
+ dataset_path: str,
359
+ out_dir: str,
360
+ model_name: str,
361
+ n_trials: int = 200,
362
+ ):
363
+ os.makedirs(out_dir, exist_ok=True)
364
+
365
+ data = load_split_data(dataset_path)
366
+ print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
367
+
368
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
369
+ study.optimize(make_objective(model_name, data, out_dir), n_trials=n_trials)
370
+
371
+ trials_df = study.trials_dataframe()
372
+ trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
373
+
374
+ best = study.best_trial
375
+ best_params = dict(best.params)
376
+ best_thr = float(best.user_attrs["threshold"])
377
+ best_auc = float(best.user_attrs["auc"])
378
+ best_ap = float(best.user_attrs["ap"])
379
+ best_f1 = float(best.value)
380
+
381
+ # Refit best model on train
382
+ if model_name == "xgb":
383
+ params = {
384
+ "objective": "binary:logistic",
385
+ "eval_metric": "logloss",
386
+ "lambda": best_params["lambda"],
387
+ "alpha": best_params["alpha"],
388
+ "colsample_bytree": best_params["colsample_bytree"],
389
+ "subsample": best_params["subsample"],
390
+ "learning_rate": best_params["learning_rate"],
391
+ "max_depth": best_params["max_depth"],
392
+ "min_child_weight": best_params["min_child_weight"],
393
+ "gamma": best_params["gamma"],
394
+ "tree_method": "hist",
395
+ "num_boost_round": best_params["num_boost_round"],
396
+ "early_stopping_rounds": best_params["early_stopping_rounds"],
397
+ }
398
+ model, p_tr, p_va = train_xgb(
399
+ data.X_train, data.y_train, data.X_val, data.y_val, params
400
+ )
401
+ model_path = os.path.join(out_dir, "best_model.json")
402
+ model.save_model(model_path)
403
+
404
+ elif model_name == "svm":
405
+ svm_kind = best_params["svm_kind"]
406
+ if svm_kind == "svc":
407
+ model, p_tr, p_va = train_svm(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
408
+ else:
409
+ model, p_tr, p_va = train_linearsvm_calibrated(data.X_train, data.y_train, data.X_val, data.y_val, best_params)
410
+
411
+ model_path = os.path.join(out_dir, "best_model.joblib")
412
+ joblib.dump(model, model_path)
413
+ elif model_name == "svm_gpu":
414
+ model, p_tr, p_va = train_cuml_svc(
415
+ data.X_train, data.y_train, data.X_val, data.y_val, best_params
416
+ )
417
+ model_path = os.path.join(out_dir, "best_model_cuml_svc.joblib")
418
+ joblib.dump(model, model_path)
419
+
420
+ elif model_name == "enet_gpu":
421
+ model, p_tr, p_va = train_cuml_elastic_net(
422
+ data.X_train, data.y_train, data.X_val, data.y_val, best_params
423
+ )
424
+ model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
425
+ joblib.dump(model, model_path)
426
+ else:
427
+ raise ValueError(model_name)
428
+
429
+ # Save predictions CSVs
430
+ save_predictions_csv(out_dir, "train", data.y_train, p_tr, best_thr, data.seq_train)
431
+ save_predictions_csv(out_dir, "val", data.y_val, p_va, best_thr, data.seq_val)
432
+
433
+ # Plots on val
434
+ plot_curves(out_dir, data.y_val, p_va)
435
+
436
+ summary = [
437
+ "=" * 72,
438
+ f"MODEL: {model_name}",
439
+ f"Best trial: {best.number}",
440
+ f"Best F1 (val @ best-threshold): {best_f1:.4f}",
441
+ f"Val AUC: {best_auc:.4f}",
442
+ f"Val AP: {best_ap:.4f}",
443
+ f"Best threshold (picked on val): {best_thr:.4f}",
444
+ f"Model saved to: {model_path}",
445
+ "Best params:",
446
+ json.dumps(best_params, indent=2),
447
+ "=" * 72,
448
+ ]
449
+ with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
450
+ f.write("\n".join(summary))
451
+ print("\n".join(summary))
452
+
453
+
454
+ if __name__ == "__main__":
455
+ import argparse
456
+ parser = argparse.ArgumentParser()
457
+ parser.add_argument("--dataset_path", type=str, required=True)
458
+ parser.add_argument("--out_dir", type=str, required=True)
459
+ parser.add_argument("--model", type=str, choices=["xgb", "svm_gpu", "enet_gpu"], required=True)
460
+ parser.add_argument("--n_trials", type=int, default=200)
461
+ args = parser.parse_args()
462
+
463
+ run_optuna_and_refit(
464
+ dataset_path=args.dataset_path,
465
+ out_dir=args.out_dir,
466
+ model_name=args.model,
467
+ n_trials=args.n_trials,
468
+ )
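As a hypothetical sketch (paths and values are placeholders, not the real training data), this is the on-disk layout that load_split_data in the scripts above expects: a DatasetDict with train/val splits and embedding/label (optionally sequence) columns.

import numpy as np
from datasets import Dataset, DatasetDict

def toy_split(n: int, dim: int = 8) -> Dataset:
    # Random embeddings and binary labels, purely for illustration
    return Dataset.from_dict({
        "embedding": np.random.rand(n, dim).astype(np.float32).tolist(),
        "label": np.random.randint(0, 2, size=n).tolist(),
        "sequence": [f"PEPTIDE_{i}" for i in range(n)],
    })

dd = DatasetDict({"train": toy_split(32), "val": toy_split(8)})
dd.save_to_disk("toy_embedding_dataset")  # pass this directory as --dataset_path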
training_classifiers/.ipynb_checkpoints/train_ml_regression-checkpoint.py ADDED
@@ -0,0 +1,410 @@
1
+ import os
2
+ import json
3
+ import joblib
4
+ import optuna
5
+ import numpy as np
6
+ import pandas as pd
7
+ import matplotlib.pyplot as plt
8
+ from dataclasses import dataclass
9
+ from typing import Dict, Any, Tuple, Optional
10
+ from datasets import load_from_disk, DatasetDict
11
+ from sklearn.preprocessing import StandardScaler
12
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
13
+ from sklearn.svm import SVR
14
+ import xgboost as xgb
15
+ from lightning.pytorch import seed_everything
16
+ import cupy as cp
17
+ from cuml.linear_model import ElasticNet as cuElasticNet
18
+ from scipy.stats import spearmanr
19
+ seed_everything(1986)
20
+
21
+
22
+ # -----------------------------
23
+ # GPU/CPU helpers
24
+ # -----------------------------
25
+ def to_gpu(X: np.ndarray):
26
+ if isinstance(X, cp.ndarray):
27
+ return X
28
+ return cp.asarray(X, dtype=cp.float32)
29
+
30
+ def to_cpu(x):
31
+ if isinstance(x, cp.ndarray):
32
+ return cp.asnumpy(x)
33
+ return np.asarray(x)
34
+
35
+
36
+ # -----------------------------
37
+ # Data loading
38
+ # -----------------------------
39
+ @dataclass
40
+ class SplitData:
41
+ X_train: np.ndarray
42
+ y_train: np.ndarray
43
+ seq_train: Optional[np.ndarray]
44
+ X_val: np.ndarray
45
+ y_val: np.ndarray
46
+ seq_val: Optional[np.ndarray]
47
+
48
+ def _stack_embeddings(col) -> np.ndarray:
49
+ arr = np.asarray(col, dtype=np.float32)
50
+ if arr.ndim != 2:
51
+ arr = np.stack(col).astype(np.float32)
52
+ return arr
53
+
54
+ def load_split_data(dataset_path: str) -> SplitData:
55
+ ds = load_from_disk(dataset_path)
56
+
57
+ if isinstance(ds, DatasetDict) and "train" in ds and "val" in ds:
58
+ train_ds, val_ds = ds["train"], ds["val"]
59
+ else:
60
+ if "split" not in ds.column_names:
61
+ raise ValueError("Dataset must be a DatasetDict(train/val) or have a 'split' column.")
62
+ train_ds = ds.filter(lambda x: x["split"] == "train")
63
+ val_ds = ds.filter(lambda x: x["split"] == "val")
64
+
65
+ for required in ["embedding", "label"]:
66
+ if required not in train_ds.column_names:
67
+ raise ValueError(f"Missing column '{required}' in train split.")
68
+ if required not in val_ds.column_names:
69
+ raise ValueError(f"Missing column '{required}' in val split.")
70
+
71
+ X_train = _stack_embeddings(train_ds["embedding"]).astype(np.float32)
72
+ X_val = _stack_embeddings(val_ds["embedding"]).astype(np.float32)
73
+
74
+ y_train = np.asarray(train_ds["label"], dtype=np.float32)
75
+ y_val = np.asarray(val_ds["label"], dtype=np.float32)
76
+
77
+ seq_train = None
78
+ seq_val = None
79
+ if "sequence" in train_ds.column_names:
80
+ seq_train = np.asarray(train_ds["sequence"])
81
+ if "sequence" in val_ds.column_names:
82
+ seq_val = np.asarray(val_ds["sequence"])
83
+
84
+ return SplitData(X_train, y_train, seq_train, X_val, y_val, seq_val)
85
+
86
+
87
+ # -----------------------------
88
+ # Metrics
89
+ # -----------------------------
90
+ def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
91
+ rho = spearmanr(y_true, y_pred).correlation
92
+ if rho is None or np.isnan(rho):
93
+ return 0.0
94
+ return float(rho)
95
+
96
+ def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
97
+ # RMSE
98
+ try:
99
+ from sklearn.metrics import root_mean_squared_error
100
+ rmse = root_mean_squared_error(y_true, y_pred)
101
+ except Exception:
102
+ rmse = float(np.sqrt(mean_squared_error(y_true, y_pred)))
103
+
104
+ mae = float(mean_absolute_error(y_true, y_pred))
105
+ r2 = float(r2_score(y_true, y_pred))
106
+ rho = float(safe_spearmanr(y_true, y_pred))
107
+ return {"rmse": rmse, "mae": mae, "r2": r2, "spearman_rho": rho}
108
+
109
+
110
+ # -----------------------------
111
+ # Model
112
+ # -----------------------------
113
+ def train_xgb_reg(
114
+ X_train, y_train, X_val, y_val, params: Dict[str, Any]
115
+ ) -> Tuple[xgb.Booster, np.ndarray, np.ndarray]:
116
+ dtrain = xgb.DMatrix(X_train, label=y_train)
117
+ dval = xgb.DMatrix(X_val, label=y_val)
118
+
119
+ num_boost_round = int(params.pop("num_boost_round"))
120
+ early_stopping_rounds = int(params.pop("early_stopping_rounds"))
121
+
122
+ booster = xgb.train(
123
+ params=params,
124
+ dtrain=dtrain,
125
+ num_boost_round=num_boost_round,
126
+ evals=[(dval, "val")],
127
+ early_stopping_rounds=early_stopping_rounds,
128
+ verbose_eval=False,
129
+ )
130
+
131
+ p_train = booster.predict(dtrain)
132
+ p_val = booster.predict(dval)
133
+ return booster, p_train, p_val
134
+
135
+
136
+ def train_cuml_elasticnet_reg(
137
+ X_train, y_train, X_val, y_val, params: Dict[str, Any]
138
+ ):
139
+ Xtr = to_gpu(X_train)
140
+ Xva = to_gpu(X_val)
141
+ ytr = to_gpu(y_train).astype(cp.float32)
142
+
143
+ model = cuElasticNet(
144
+ alpha=float(params["alpha"]),
145
+ l1_ratio=float(params["l1_ratio"]),
146
+ fit_intercept=True,
147
+ max_iter=int(params.get("max_iter", 5000)),
148
+ tol=float(params.get("tol", 1e-4)),
149
+ selection=params.get("selection", "cyclic"),
150
+ )
151
+ model.fit(Xtr, ytr)
152
+
153
+ p_train = to_cpu(model.predict(Xtr))
154
+ p_val = to_cpu(model.predict(Xva))
155
+ return model, p_train, p_val
156
+
157
+
158
+ def train_svr_reg(
159
+ X_train, y_train, X_val, y_val, params: Dict[str, Any]
160
+ ):
161
+ model = SVR(
162
+ C=float(params["C"]),
163
+ epsilon=float(params["epsilon"]),
164
+ kernel=params["kernel"],
165
+ gamma=params.get("gamma", "scale"),
166
+ )
167
+ model.fit(X_train, y_train)
168
+ p_train = model.predict(X_train)
169
+ p_val = model.predict(X_val)
170
+ return model, p_train, p_val
171
+
172
+
173
+ # -----------------------------
174
+ # Saving + plots
175
+ # -----------------------------
176
+ def save_predictions_csv(
177
+ out_dir: str,
178
+ split_name: str,
179
+ y_true: np.ndarray,
180
+ y_pred: np.ndarray,
181
+ sequences: Optional[np.ndarray] = None,
182
+ ):
183
+ os.makedirs(out_dir, exist_ok=True)
184
+ df = pd.DataFrame({
185
+ "y_true": y_true.astype(float),
186
+ "y_pred": y_pred.astype(float),
187
+ "residual": (y_true - y_pred).astype(float),
188
+ })
189
+ if sequences is not None:
190
+ df.insert(0, "sequence", sequences)
191
+ df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
192
+
193
+ def plot_regression_diagnostics(out_dir: str, y_true: np.ndarray, y_pred: np.ndarray):
194
+ os.makedirs(out_dir, exist_ok=True)
195
+
196
+ plt.figure()
197
+ plt.scatter(y_true, y_pred, s=8, alpha=0.5)
198
+ plt.xlabel("y_true")
199
+ plt.ylabel("y_pred")
200
+ plt.title("Predicted vs True")
201
+ plt.tight_layout()
202
+ plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
203
+ plt.close()
204
+
205
+ resid = y_true - y_pred
206
+ plt.figure()
207
+ plt.hist(resid, bins=50)
208
+ plt.xlabel("residual (y_true - y_pred)")
209
+ plt.ylabel("count")
210
+ plt.title("Residual Histogram")
211
+ plt.tight_layout()
212
+ plt.savefig(os.path.join(out_dir, "residual_hist.png"))
213
+ plt.close()
214
+
215
+ plt.figure()
216
+ plt.scatter(y_pred, resid, s=8, alpha=0.5)
217
+ plt.xlabel("y_pred")
218
+ plt.ylabel("residual")
219
+ plt.title("Residuals vs Prediction")
220
+ plt.tight_layout()
221
+ plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
222
+ plt.close()
223
+
224
+
225
+ # -----------------------------
226
+ # Optuna objective (OPTIMIZE SPEARMAN RHO)
227
+ # -----------------------------
228
+ def make_objective(model_name: str, data: SplitData):
229
+ Xtr, ytr, Xva, yva = data.X_train, data.y_train, data.X_val, data.y_val
230
+
231
+ def objective(trial: optuna.Trial) -> float:
232
+ if model_name == "xgb_reg":
233
+ params = {
234
+ "objective": "reg:squarederror",
235
+ "eval_metric": "rmse",
236
+ "lambda": trial.suggest_float("lambda", 1e-10, 100.0, log=True),
237
+ "alpha": trial.suggest_float("alpha", 1e-10, 100.0, log=True),
238
+ "gamma": trial.suggest_float("gamma", 0.0, 10.0),
239
+ "max_depth": trial.suggest_int("max_depth", 2, 16),
240
+ "min_child_weight": trial.suggest_float("min_child_weight", 1e-3, 500.0, log=True),
241
+ "subsample": trial.suggest_float("subsample", 0.5, 1.0),
242
+ "colsample_bytree": trial.suggest_float("colsample_bytree", 0.3, 1.0),
243
+ "learning_rate": trial.suggest_float("learning_rate", 1e-3, 0.3, log=True),
244
+ "tree_method": "hist",
245
+ "device": "cuda",
246
+ }
247
+ params["num_boost_round"] = trial.suggest_int("num_boost_round", 50, 2000)
248
+ params["early_stopping_rounds"] = trial.suggest_int("early_stopping_rounds", 20, 200)
249
+
250
+ model, p_tr, p_va = train_xgb_reg(Xtr, ytr, Xva, yva, params.copy())
251
+
252
+ elif model_name == "enet_gpu":
253
+ params = {
254
+ "alpha": trial.suggest_float("alpha", 1e-8, 10.0, log=True),
255
+ "l1_ratio": trial.suggest_float("l1_ratio", 0.0, 1.0),
256
+ "max_iter": trial.suggest_int("max_iter", 1000, 20000),
257
+ "tol": trial.suggest_float("tol", 1e-6, 1e-2, log=True),
258
+ "selection": trial.suggest_categorical("selection", ["cyclic", "random"]),
259
+ }
260
+ model, p_tr, p_va = train_cuml_elasticnet_reg(Xtr, ytr, Xva, yva, params)
261
+
262
+ elif model_name == "svr":
263
+ params = {
264
+ "kernel": trial.suggest_categorical("kernel", ["rbf", "linear", "poly", "sigmoid"]),
265
+ "C": trial.suggest_float("C", 1e-3, 1e3, log=True),
266
+ "epsilon": trial.suggest_float("epsilon", 1e-4, 1.0, log=True),
267
+ }
268
+ if params["kernel"] in ["rbf", "poly", "sigmoid"]:
269
+ params["gamma"] = trial.suggest_float("gamma", 1e-6, 10.0, log=True)
270
+ else:
271
+ params["gamma"] = "scale"
272
+
273
+ model, p_tr, p_va = train_svr_reg(Xtr, ytr, Xva, yva, params)
274
+
275
+ else:
276
+ raise ValueError(f"Unknown model_name={model_name}")
277
+
278
+ metrics = eval_regression(yva, p_va)
279
+ trial.set_user_attr("spearman_rho", metrics["spearman_rho"])
280
+ trial.set_user_attr("rmse", metrics["rmse"])
281
+ trial.set_user_attr("mae", metrics["mae"])
282
+ trial.set_user_attr("r2", metrics["r2"])
283
+
284
+ # OPTUNA OBJECTIVE = maximize Spearman rho
285
+ return metrics["spearman_rho"]
286
+
287
+ return objective
288
+
289
+
290
+ # -----------------------------
291
+ # Main
292
+ # -----------------------------
293
+ def run_optuna_and_refit(
294
+ dataset_path: str,
295
+ out_dir: str,
296
+ model_name: str,
297
+ n_trials: int = 200,
298
+ standardize_X: bool = True,
299
+ ):
300
+ os.makedirs(out_dir, exist_ok=True)
301
+
302
+ data = load_split_data(dataset_path)
303
+ print(f"[Data] Train: {data.X_train.shape}, Val: {data.X_val.shape}")
304
+
305
+ # Standardize features (SVR + ElasticNet)
306
+ if standardize_X:
307
+ scaler = StandardScaler()
308
+ data.X_train = scaler.fit_transform(data.X_train).astype(np.float32)
309
+ data.X_val = scaler.transform(data.X_val).astype(np.float32)
310
+ joblib.dump(scaler, os.path.join(out_dir, "scaler.joblib"))
311
+ print("[Preprocess] Saved StandardScaler -> scaler.joblib")
312
+
313
+ study = optuna.create_study(
314
+ direction="maximize",
315
+ pruner=optuna.pruners.MedianPruner()
316
+ )
317
+ study.optimize(make_objective(model_name, data), n_trials=n_trials)
318
+
319
+ trials_df = study.trials_dataframe()
320
+ trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
321
+
322
+ best = study.best_trial
323
+ best_params = dict(best.params)
324
+
325
+ best_rho = float(best.user_attrs.get("spearman_rho", best.value))
326
+ best_rmse = float(best.user_attrs.get("rmse", np.nan))
327
+ best_mae = float(best.user_attrs.get("mae", np.nan))
328
+ best_r2 = float(best.user_attrs.get("r2", np.nan))
329
+
330
+ # Refit best model on train
331
+ if model_name == "xgb_reg":
332
+ params = {
333
+ "objective": "reg:squarederror",
334
+ "eval_metric": "rmse",
335
+ "lambda": best_params["lambda"],
336
+ "alpha": best_params["alpha"],
337
+ "gamma": best_params["gamma"],
338
+ "max_depth": best_params["max_depth"],
339
+ "min_child_weight": best_params["min_child_weight"],
340
+ "subsample": best_params["subsample"],
341
+ "colsample_bytree": best_params["colsample_bytree"],
342
+ "learning_rate": best_params["learning_rate"],
343
+ "tree_method": "hist",
344
+ "device": "cuda",
345
+ "num_boost_round": best_params["num_boost_round"],
346
+ "early_stopping_rounds": best_params["early_stopping_rounds"],
347
+ }
348
+ model, p_tr, p_va = train_xgb_reg(
349
+ data.X_train, data.y_train, data.X_val, data.y_val, params
350
+ )
351
+ model_path = os.path.join(out_dir, "best_model.json")
352
+ model.save_model(model_path)
353
+
354
+ elif model_name == "enet_gpu":
355
+ model, p_tr, p_va = train_cuml_elasticnet_reg(
356
+ data.X_train, data.y_train, data.X_val, data.y_val, best_params
357
+ )
358
+ model_path = os.path.join(out_dir, "best_model_cuml_enet.joblib")
359
+ joblib.dump(model, model_path)
360
+
361
+ elif model_name == "svr":
362
+ model, p_tr, p_va = train_svr_reg(
363
+ data.X_train, data.y_train, data.X_val, data.y_val, best_params
364
+ )
365
+ model_path = os.path.join(out_dir, "best_model_svr.joblib")
366
+ joblib.dump(model, model_path)
367
+
368
+ else:
369
+ raise ValueError(model_name)
370
+
371
+ save_predictions_csv(out_dir, "train", data.y_train, p_tr, data.seq_train)
372
+ save_predictions_csv(out_dir, "val", data.y_val, p_va, data.seq_val)
373
+
374
+ plot_regression_diagnostics(out_dir, data.y_val, p_va)
375
+
376
+ summary = [
377
+ "=" * 72,
378
+ f"MODEL: {model_name}",
379
+ f"Best trial: {best.number}",
380
+ f"Val Spearman rho (objective): {best_rho:.6f}",
381
+ f"Val RMSE: {best_rmse:.6f}",
382
+ f"Val MAE: {best_mae:.6f}",
383
+ f"Val R2: {best_r2:.6f}",
384
+ f"Model saved to: {model_path}",
385
+ "Best params:",
386
+ json.dumps(best_params, indent=2),
387
+ "=" * 72,
388
+ ]
389
+ with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
390
+ f.write("\n".join(summary))
391
+ print("\n".join(summary))
392
+
393
+
394
+ if __name__ == "__main__":
395
+ import argparse
396
+ parser = argparse.ArgumentParser()
397
+ parser.add_argument("--dataset_path", type=str, required=True)
398
+ parser.add_argument("--out_dir", type=str, required=True)
399
+ parser.add_argument("--model", type=str, choices=["xgb_reg", "enet_gpu", "svr"], required=True)
400
+ parser.add_argument("--n_trials", type=int, default=200)
401
+ parser.add_argument("--no_standardize", action="store_true", help="Disable StandardScaler on X")
402
+ args = parser.parse_args()
403
+
404
+ run_optuna_and_refit(
405
+ dataset_path=args.dataset_path,
406
+ out_dir=args.out_dir,
407
+ model_name=args.model,
408
+ n_trials=args.n_trials,
409
+ standardize_X=(not args.no_standardize),
410
+ )
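For reference, a minimal inference sketch (not part of the commit) for the `xgb_reg` branch above. It assumes an output directory produced by `run_optuna_and_refit` with standardization enabled and a hypothetical 1280-dimensional pooled feature matrix; only the `scaler.joblib` and `best_model.json` filenames come from the script itself.

    import joblib
    import numpy as np
    import xgboost as xgb

    out_dir = "path/to/xgb_reg_run"                       # hypothetical output directory
    X_new = np.random.rand(4, 1280).astype(np.float32)    # placeholder pooled embeddings

    scaler = joblib.load(f"{out_dir}/scaler.joblib")      # written only when standardize_X=True
    booster = xgb.Booster()
    booster.load_model(f"{out_dir}/best_model.json")
    preds = booster.predict(xgb.DMatrix(scaler.transform(X_new)))
    print(preds)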
training_classifiers/.ipynb_checkpoints/train_nn-checkpoint.py ADDED
@@ -0,0 +1,426 @@
1
+ import numpy as np
2
+ import torch
3
+ from torch.utils.data import DataLoader
4
+ from datasets import load_from_disk, DatasetDict
5
+ from sklearn.metrics import roc_auc_score, precision_recall_curve, f1_score
6
+ import torch.nn as nn
7
+ import optuna
8
+ import os
9
+ from typing import Dict, Any, Tuple, Optional
10
+ import matplotlib.pyplot as plt
11
+ from sklearn.metrics import (
12
+ f1_score, roc_auc_score, average_precision_score,
13
+ precision_recall_curve, roc_curve
14
+ )
15
+ import json
16
+ import joblib
17
+ import pandas as pd
18
+ import time
19
+
20
+ def infer_in_dim_from_unpooled_ds(ds) -> int:
21
+ ex = ds[0]
22
+ # ex["embedding"] is (L, H) list/array
23
+ return int(len(ex["embedding"][0]))
24
+
25
+ def load_split(dataset_path):
26
+ ds = load_from_disk(dataset_path)
27
+
28
+ if isinstance(ds, DatasetDict):
29
+ return ds["train"], ds["val"]
30
+
31
+ raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
32
+
33
+ def collate_unpooled(batch):
34
+ # batch: list of dicts
35
+ lengths = [int(x["length"]) for x in batch]
36
+ Lmax = max(lengths)
37
+ H = len(batch[0]["embedding"][0]) # embedding width (e.g. 1280 for ESM2-650M)
38
+
39
+ X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
40
+ M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
41
+ y = torch.tensor([x["label"] for x in batch], dtype=torch.float32)
42
+
43
+ for i, x in enumerate(batch):
44
+ emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L, H)
45
+ L = emb.shape[0]
46
+ X[i, :L] = emb
47
+ if "attention_mask" in x:
48
+ m = torch.tensor(x["attention_mask"], dtype=torch.bool)
49
+ M[i, :L] = m[:L]
50
+ else:
51
+ M[i, :L] = True
52
+
53
+ return X, M, y
54
+
55
+ # ======================== Helper functions =========================================
56
+ def save_predictions_csv(
57
+ out_dir: str,
58
+ split_name: str,
59
+ y_true: np.ndarray,
60
+ y_prob: np.ndarray,
61
+ threshold: float,
62
+ sequences: Optional[np.ndarray] = None,
63
+ ):
64
+ os.makedirs(out_dir, exist_ok=True)
65
+ df = pd.DataFrame({
66
+ "y_true": y_true.astype(int),
67
+ "y_prob": y_prob.astype(float),
68
+ "y_pred": (y_prob >= threshold).astype(int),
69
+ })
70
+ if sequences is not None:
71
+ df.insert(0, "sequence", sequences)
72
+ df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
73
+
74
+
75
+ def plot_curves(out_dir: str, y_true: np.ndarray, y_prob: np.ndarray):
76
+ os.makedirs(out_dir, exist_ok=True)
77
+
78
+ # PR
79
+ precision, recall, _ = precision_recall_curve(y_true, y_prob)
80
+ plt.figure()
81
+ plt.plot(recall, precision)
82
+ plt.xlabel("Recall")
83
+ plt.ylabel("Precision")
84
+ plt.title("Precision-Recall Curve")
85
+ plt.tight_layout()
86
+ plt.savefig(os.path.join(out_dir, "pr_curve.png"))
87
+ plt.close()
88
+
89
+ # ROC
90
+ fpr, tpr, _ = roc_curve(y_true, y_prob)
91
+ plt.figure()
92
+ plt.plot(fpr, tpr)
93
+ plt.xlabel("False Positive Rate")
94
+ plt.ylabel("True Positive Rate")
95
+ plt.title("ROC Curve")
96
+ plt.tight_layout()
97
+ plt.savefig(os.path.join(out_dir, "roc_curve.png"))
98
+ plt.close()
99
+
100
+ # ======================== Shared OPTUNA training scheme =========================================
101
+ def best_f1_threshold(y_true, y_prob):
102
+ p, r, thr = precision_recall_curve(y_true, y_prob)
103
+ f1s = (2*p[:-1]*r[:-1])/(p[:-1]+r[:-1]+1e-12)
104
+ i = int(np.nanargmax(f1s))
105
+ return float(thr[i]), float(f1s[i])
106
+
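# Illustration (not part of the commit): with cleanly separated scores, best_f1_threshold
# picks a cutoff between the two classes and reports F1 = 1.0, e.g.
#   thr, f1 = best_f1_threshold(np.array([0, 0, 1, 1]), np.array([0.1, 0.2, 0.8, 0.9]))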
107
+ @torch.no_grad()
108
+ def eval_probs(model, loader, device):
109
+ model.eval()
110
+ ys, ps = [], []
111
+ for X, M, y in loader:
112
+ X, M = X.to(device), M.to(device)
113
+ logits = model(X, M)
114
+ prob = torch.sigmoid(logits).detach().cpu().numpy()
115
+ ys.append(y.numpy())
116
+ ps.append(prob)
117
+ return np.concatenate(ys), np.concatenate(ps)
118
+
119
+ def train_one_epoch(model, loader, optim, criterion, device):
120
+ model.train()
121
+ for X, M, y in loader:
122
+ X, M, y = X.to(device), M.to(device), y.to(device)
123
+ optim.zero_grad(set_to_none=True)
124
+ logits = model(X, M)
125
+ loss = criterion(logits, y)
126
+ loss.backward()
127
+ optim.step()
128
+
129
+ # ======================== MLP =========================================
130
+ # The MLP takes one vector per sequence, so masked mean-pool over the length dimension first
131
+ class MaskedMeanPool(nn.Module):
132
+ def forward(self, X, M): # X: (B,L,H), M: (B,L)
133
+ Mf = M.unsqueeze(-1).float()
134
+ denom = Mf.sum(dim=1).clamp(min=1.0)
135
+ return (X * Mf).sum(dim=1) / denom # (B,H)
136
+
137
+ class MLPClassifier(nn.Module):
138
+ def __init__(self, in_dim, hidden=512, dropout=0.1):
139
+ super().__init__()
140
+ self.pool = MaskedMeanPool()
141
+ self.net = nn.Sequential(
142
+ nn.Linear(in_dim, hidden),
143
+ nn.GELU(),
144
+ nn.Dropout(dropout),
145
+ nn.Linear(hidden, 1),
146
+ )
147
+ def forward(self, X, M):
148
+ z = self.pool(X, M)
149
+ return self.net(z).squeeze(-1) # logits
150
+
151
+ # ======================== CNN =========================================
152
+ # Treat the embedding dimensions (e.g. 1280) as Conv1d input channels
153
+ class CNNClassifier(nn.Module):
154
+ def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
155
+ super().__init__()
156
+ blocks = []
157
+ ch = in_ch
158
+ for _ in range(layers):
159
+ blocks += [
160
+ nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
161
+ nn.GELU(),
162
+ nn.Dropout(dropout),
163
+ ]
164
+ ch = c
165
+ self.conv = nn.Sequential(*blocks)
166
+ self.head = nn.Linear(c, 1)
167
+
168
+ def forward(self, X, M):
169
+ # X: (B,L,H) -> (B,H,L)
170
+ Xc = X.transpose(1, 2)
171
+ Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
172
+
173
+ # masked mean pool over L
174
+ Mf = M.unsqueeze(-1).float()
175
+ denom = Mf.sum(dim=1).clamp(min=1.0)
176
+ pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
177
+ return self.head(pooled).squeeze(-1)
178
+
179
+ # ========================== Transformer ====================================
180
+ class TransformerClassifier(nn.Module):
181
+ def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
182
+ super().__init__()
183
+ self.proj = nn.Linear(in_dim, d_model)
184
+ enc_layer = nn.TransformerEncoderLayer(
185
+ d_model=d_model, nhead=nhead, dim_feedforward=ff,
186
+ dropout=dropout, batch_first=True, activation="gelu"
187
+ )
188
+ self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
189
+ self.head = nn.Linear(d_model, 1)
190
+
191
+ def forward(self, X, M):
192
+ # src_key_padding_mask: True = pad positions
193
+ pad_mask = ~M
194
+ Z = self.proj(X) # (B,L,d)
195
+ Z = self.enc(Z, src_key_padding_mask=pad_mask) # (B,L,d)
196
+
197
+ Mf = M.unsqueeze(-1).float()
198
+ denom = Mf.sum(dim=1).clamp(min=1.0)
199
+ pooled = (Z * Mf).sum(dim=1) / denom
200
+ return self.head(pooled).squeeze(-1)
201
+
202
+ # ========================== OPTUNA ====================================
203
+
204
+ def objective_nn(trial, model_name, train_ds, val_ds, device="cuda:0"):
205
+ # hyperparams shared
206
+ lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
207
+ wd = trial.suggest_float("weight_decay", 1e-8, 1e-2, log=True)
208
+ dropout = trial.suggest_float("dropout", 0.0, 0.5)
209
+ batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
210
+
211
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
212
+ collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
213
+ val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
214
+ collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
215
+
216
+ in_dim = infer_in_dim_from_unpooled_ds(train_ds)
217
+
218
+ if model_name == "mlp":
219
+ hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
220
+ model = MLPClassifier(in_dim=in_dim, hidden=hidden, dropout=dropout)
221
+ elif model_name == "cnn":
222
+ c = trial.suggest_categorical("channels", [128, 256, 512])
223
+ k = trial.suggest_categorical("kernel", [3, 5, 7])
224
+ layers = trial.suggest_int("layers", 1, 4)
225
+ model = CNNClassifier(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
226
+ elif model_name == "transformer":
227
+ d = trial.suggest_categorical("d_model", [128, 256, 384])
228
+ nhead = trial.suggest_categorical("nhead", [4, 8])
229
+ layers = trial.suggest_int("layers", 1, 4)
230
+ ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
231
+ model = TransformerClassifier(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
232
+ else:
233
+ raise ValueError(model_name)
234
+
235
+ model = model.to(device)
236
+
237
+ # class imbalance handling
238
+ ytr = np.asarray(train_ds["label"], dtype=np.int64)
239
+ pos = ytr.sum()
240
+ neg = len(ytr) - pos
241
+ pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
242
+ criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
243
+
244
+ optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
245
+
246
+ best_f1 = -1.0
247
+ patience = 8
248
+ bad = 0
249
+
250
+ for epoch in range(1, 51):
251
+ train_one_epoch(model, train_loader, optim, criterion, device)
252
+
253
+ y_true, y_prob = eval_probs(model, val_loader, device)
254
+ auc = roc_auc_score(y_true, y_prob)
255
+
256
+ thr, f1 = best_f1_threshold(y_true, y_prob)
257
+
258
+ trial.set_user_attr("val_auc", float(auc))
259
+ trial.set_user_attr("val_f1", float(f1))
260
+ trial.set_user_attr("val_thr", float(thr))
261
+
262
+ # prune
263
+ trial.report(f1, epoch)
264
+ if trial.should_prune():
265
+ raise optuna.TrialPruned()
266
+
267
+ if f1 > best_f1 + 1e-4:
268
+ best_f1 = f1
269
+ bad = 0
270
+ else:
271
+ bad += 1
272
+ if bad >= patience:
273
+ break
274
+
275
+ return best_f1
276
+
277
+ def run_optuna_and_refit_nn(dataset_path: str, out_dir: str, model_name: str, n_trials: int = 50, device="cuda:0"):
278
+ os.makedirs(out_dir, exist_ok=True)
279
+
280
+ train_ds, val_ds = load_split(dataset_path)
281
+ print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
282
+
283
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
284
+ study.optimize(lambda trial: objective_nn(trial, model_name, train_ds, val_ds, device=device), n_trials=n_trials)
285
+
286
+ trials_df = study.trials_dataframe()
287
+ trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
288
+
289
+ best = study.best_trial
290
+ best_params = dict(best.params)
291
+ best_f1_optuna = float(best.value)
292
+ best_auc_optuna = float(best.user_attrs.get("val_auc", np.nan))
293
+ best_thr = float(best.user_attrs.get("val_thr", 0.5))
294
+
295
+ in_dim = infer_in_dim_from_unpooled_ds(train_ds)
296
+
297
+ # --- Refit best model ---
298
+ batch_size = int(best_params.get("batch_size", 32))
299
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
300
+ collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
301
+ val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
302
+ collate_fn=collate_unpooled, num_workers=4, pin_memory=True)
303
+
304
+ # Rebuild
305
+ dropout = float(best_params.get("dropout", 0.1))
306
+ if model_name == "mlp":
307
+ model = MLPClassifier(
308
+ in_dim=in_dim,
309
+ hidden=int(best_params["hidden"]),
310
+ dropout=dropout,
311
+ )
312
+
313
+ elif model_name == "cnn":
314
+ model = CNNClassifier(
315
+ in_ch=in_dim,
316
+ c=int(best_params["channels"]),
317
+ k=int(best_params["kernel"]),
318
+ layers=int(best_params["layers"]),
319
+ dropout=dropout,
320
+ )
321
+
322
+ elif model_name == "transformer":
323
+ model = TransformerClassifier(
324
+ in_dim=in_dim,
325
+ d_model=int(best_params["d_model"]),
326
+ nhead=int(best_params["nhead"]),
327
+ layers=int(best_params["layers"]),
328
+ ff=int(best_params["ff"]),
329
+ dropout=dropout,
330
+ )
331
+ else:
332
+ raise ValueError(model_name)
333
+
334
+ model = model.to(device)
335
+
336
+ # loss + optimizer
337
+ ytr = np.asarray(train_ds["label"], dtype=np.int64)
338
+ pos = ytr.sum()
339
+ neg = len(ytr) - pos
340
+ pos_weight = torch.tensor([neg / max(pos, 1)], device=device, dtype=torch.float32)
341
+ criterion = torch.nn.BCEWithLogitsLoss(pos_weight=pos_weight)
342
+
343
+ lr = float(best_params["lr"])
344
+ wd = float(best_params["weight_decay"])
345
+ optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
346
+
347
+ # train longer with early stopping on AUC
348
+ best_f1_seen, bad, patience = -1.0, 0, 12
349
+ best_state = None
350
+ best_thr_seen = 0.5
351
+ best_auc_seen = -1.0
352
+
353
+ for epoch in range(1, 151):
354
+ train_one_epoch(model, train_loader, optim, criterion, device)
355
+
356
+ y_true, y_prob = eval_probs(model, val_loader, device)
357
+ auc = roc_auc_score(y_true, y_prob)
358
+ thr, f1 = best_f1_threshold(y_true, y_prob)
359
+
360
+ if f1 > best_f1_seen + 1e-4:
361
+ best_f1_seen = f1
362
+ best_thr_seen = thr
363
+ best_auc_seen = auc
364
+ bad = 0
365
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
366
+ else:
367
+ bad += 1
368
+ if bad >= patience:
369
+ break
370
+
371
+ if best_state is not None:
372
+ model.load_state_dict(best_state)
373
+
374
+ # final preds + threshold picked on val
375
+ y_true_val, y_prob_val = eval_probs(model, val_loader, device)
376
+ best_thr_final, best_f1_final = best_f1_threshold(y_true_val, y_prob_val)
377
+
378
+ # save model
379
+ model_path = os.path.join(out_dir, "best_model.pt")
380
+ torch.save({"state_dict": model.state_dict(), "best_params": best_params}, model_path)
381
+
382
+ # train preds
383
+ y_true_tr, y_prob_tr = eval_probs(model, DataLoader(train_ds, batch_size=64, shuffle=False,
384
+ collate_fn=collate_unpooled, num_workers=4, pin_memory=True), device)
385
+
386
+ save_predictions_csv(out_dir, "train", y_true_tr, y_prob_tr, best_thr_final,
387
+ sequences=np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None)
388
+ save_predictions_csv(out_dir, "val", y_true_val, y_prob_val, best_thr_final,
389
+ sequences=np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None)
390
+
391
+ plot_curves(out_dir, y_true_val, y_prob_val)
392
+
393
+ summary = [
394
+ "=" * 72,
395
+ f"MODEL: {model_name}",
396
+
397
+ # Optuna results (objective = F1)
398
+ f"Best Optuna F1 (objective): {best_f1_optuna:.4f}",
399
+ f"Best Optuna AUC (val, recorded): {best_auc_optuna:.4f}",
400
+ f"Best Optuna threshold (val): {best_thr:.4f}",
401
+
402
+ # Refit results
403
+ f"Refit best AUC (val): {best_auc_seen:.4f}",
404
+ f"Refit best F1@thr (val): {best_f1_final:.4f} at thr={best_thr_final:.4f}",
405
+
406
+ "Best params:",
407
+ json.dumps(best_params, indent=2),
408
+ f"Saved model: {model_path}",
409
+ "=" * 72,
410
+ ]
411
+
412
+ with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
413
+ f.write("\n".join(summary))
414
+ print("\n".join(summary))
415
+
416
+ if __name__ == "__main__":
417
+ import argparse
418
+ parser = argparse.ArgumentParser()
419
+ parser.add_argument("--dataset_path", type=str, required=True)
420
+ parser.add_argument("--out_dir", type=str, required=True)
421
+ parser.add_argument("--model", type=str, choices=["mlp", "cnn", "transformer"], required=True)
422
+ parser.add_argument("--n_trials", type=int, default=50)
423
+ args = parser.parse_args()
424
+
425
+ if args.model in ["mlp", "cnn", "transformer"]:
426
+ run_optuna_and_refit_nn(args.dataset_path, args.out_dir, args.model, args.n_trials, device="cuda:0")
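A minimal restore-and-predict sketch (not part of the commit), assuming an `mlp` run and that `MLPClassifier` from this script is in scope; `in_dim=1280` and the checkpoint path are assumptions, since the classifier checkpoint stores only the state dict and the tuned hyperparameters.

    import torch

    ckpt = torch.load("out_dir/best_model.pt", map_location="cpu")   # hypothetical path
    params = ckpt["best_params"]
    model = MLPClassifier(in_dim=1280,                               # assumed embedding width
                          hidden=int(params["hidden"]),
                          dropout=float(params["dropout"]))
    model.load_state_dict(ckpt["state_dict"])
    model.eval()

    X = torch.randn(2, 40, 1280)              # (batch, max_len, hidden) padded embeddings
    M = torch.ones(2, 40, dtype=torch.bool)   # attention mask, True = real token
    with torch.no_grad():
        probs = torch.sigmoid(model(X, M))    # the model itself returns logits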
training_classifiers/.ipynb_checkpoints/train_nn_regression-checkpoint.py ADDED
@@ -0,0 +1,420 @@
1
+ import os, json, time
2
+ import numpy as np
3
+ import pandas as pd
4
+ import matplotlib.pyplot as plt
5
+
6
+ import torch
7
+ import torch.nn as nn
8
+ from torch.utils.data import DataLoader
9
+ from datasets import load_from_disk, DatasetDict
10
+ import optuna
11
+ from dataclasses import dataclass
12
+ from typing import Dict, Any, Tuple, Optional
13
+ from sklearn.metrics import mean_squared_error, mean_absolute_error, r2_score
14
+ from scipy.stats import spearmanr
15
+ from torch.cuda.amp import autocast, GradScaler
17
+ scaler = GradScaler(enabled=torch.cuda.is_available())
18
+ from lightning.pytorch import seed_everything
19
+ seed_everything(1986)
20
+
21
+
22
+ def load_split(dataset_path):
23
+ ds = load_from_disk(dataset_path)
24
+ if isinstance(ds, DatasetDict):
25
+ return ds["train"], ds["val"]
26
+ raise ValueError("Expected DatasetDict with 'train' and 'val' splits")
27
+
28
+ def collate_unpooled_reg(batch):
29
+ lengths = [int(x["length"]) for x in batch]
30
+ Lmax = max(lengths)
31
+ H = len(batch[0]["embedding"][0])
32
+
33
+ X = torch.zeros(len(batch), Lmax, H, dtype=torch.float32)
34
+ M = torch.zeros(len(batch), Lmax, dtype=torch.bool)
35
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
36
+
37
+ for i, x in enumerate(batch):
38
+ emb = torch.tensor(x["embedding"], dtype=torch.float32) # (L,H)
39
+ L = emb.shape[0]
40
+ X[i, :L] = emb
41
+ if "attention_mask" in x:
42
+ m = torch.tensor(x["attention_mask"], dtype=torch.bool)
43
+ M[i, :L] = m[:L]
44
+ else:
45
+ M[i, :L] = True
46
+ return X, M, y
47
+
48
+ def infer_in_dim(ds) -> int:
49
+ ex = ds[0]
50
+ return int(len(ex["embedding"][0]))
51
+
52
+ # ============================
53
+ # Metrics
54
+ # ============================
55
+ def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
56
+ rho = spearmanr(y_true, y_pred).correlation
57
+ if rho is None or np.isnan(rho):
58
+ return 0.0
59
+ return float(rho)
60
+
61
+ def eval_regression(y_true: np.ndarray, y_pred: np.ndarray) -> Dict[str, float]:
62
+ # ---- RMSE ----
63
+ try:
64
+ from sklearn.metrics import root_mean_squared_error
65
+ rmse = root_mean_squared_error(y_true, y_pred)
66
+ except Exception:
67
+ mse = mean_squared_error(y_true, y_pred)
68
+ rmse = float(np.sqrt(mse))
69
+
70
+ mae = float(mean_absolute_error(y_true, y_pred))
71
+ r2 = float(r2_score(y_true, y_pred))
72
+ rho = float(safe_spearmanr(y_true, y_pred))
73
+ return {"rmse": float(rmse), "mae": mae, "r2": r2, "spearman_rho": rho}
74
+
75
+
76
+ # ============================
77
+ # Models
78
+ # ============================
79
+ class MaskedMeanPool(nn.Module):
80
+ def forward(self, X, M):
81
+ Mf = M.unsqueeze(-1).float()
82
+ denom = Mf.sum(dim=1).clamp(min=1.0)
83
+ return (X * Mf).sum(dim=1) / denom
84
+
85
+ class MLPRegressor(nn.Module):
86
+ def __init__(self, in_dim, hidden=512, dropout=0.1):
87
+ super().__init__()
88
+ self.pool = MaskedMeanPool()
89
+ self.net = nn.Sequential(
90
+ nn.Linear(in_dim, hidden),
91
+ nn.GELU(),
92
+ nn.Dropout(dropout),
93
+ nn.Linear(hidden, 1),
94
+ )
95
+ def forward(self, X, M):
96
+ z = self.pool(X, M)
97
+ return self.net(z).squeeze(-1) # y_pred
98
+
99
+ class CNNRegressor(nn.Module):
100
+ def __init__(self, in_ch, c=256, k=5, layers=2, dropout=0.1):
101
+ super().__init__()
102
+ blocks = []
103
+ ch = in_ch
104
+ for _ in range(layers):
105
+ blocks += [
106
+ nn.Conv1d(ch, c, kernel_size=k, padding=k//2),
107
+ nn.GELU(),
108
+ nn.Dropout(dropout),
109
+ ]
110
+ ch = c
111
+ self.conv = nn.Sequential(*blocks)
112
+ self.head = nn.Linear(c, 1)
113
+
114
+ def forward(self, X, M):
115
+ Xc = X.transpose(1, 2) # (B,H,L)
116
+ Y = self.conv(Xc).transpose(1, 2) # (B,L,C)
117
+ Mf = M.unsqueeze(-1).float()
118
+ denom = Mf.sum(dim=1).clamp(min=1.0)
119
+ pooled = (Y * Mf).sum(dim=1) / denom # (B,C)
120
+ return self.head(pooled).squeeze(-1)
121
+
122
+ class TransformerRegressor(nn.Module):
123
+ def __init__(self, in_dim, d_model=256, nhead=8, layers=2, ff=512, dropout=0.1):
124
+ super().__init__()
125
+ self.proj = nn.Linear(in_dim, d_model)
126
+ enc_layer = nn.TransformerEncoderLayer(
127
+ d_model=d_model, nhead=nhead, dim_feedforward=ff,
128
+ dropout=dropout, batch_first=True, activation="gelu"
129
+ )
130
+ self.enc = nn.TransformerEncoder(enc_layer, num_layers=layers)
131
+ self.head = nn.Linear(d_model, 1)
132
+
133
+ def forward(self, X, M):
134
+ pad_mask = ~M
135
+ Z = self.proj(X)
136
+ Z = self.enc(Z, src_key_padding_mask=pad_mask)
137
+ Mf = M.unsqueeze(-1).float()
138
+ denom = Mf.sum(dim=1).clamp(min=1.0)
139
+ pooled = (Z * Mf).sum(dim=1) / denom
140
+ return self.head(pooled).squeeze(-1)
141
+
142
+ # ============================
143
+ # Train / eval
144
+ # ============================
145
+ @torch.no_grad()
146
+ def eval_preds(model, loader, device):
147
+ model.eval()
148
+ ys, ps = [], []
149
+ for X, M, y in loader:
150
+ X, M = X.to(device), M.to(device)
151
+ pred = model(X, M).detach().cpu().numpy()
152
+ ys.append(y.numpy())
153
+ ps.append(pred)
154
+ return np.concatenate(ys), np.concatenate(ps)
155
+
156
+ def train_one_epoch_reg(model, loader, optim, criterion, device):
157
+ model.train()
158
+ for X, M, y in loader:
159
+ X, M, y = X.to(device), M.to(device), y.to(device)
160
+ optim.zero_grad(set_to_none=True)
161
+ with autocast(enabled=torch.cuda.is_available()):
162
+ pred = model(X, M)
163
+ loss = criterion(pred, y)
164
+ scaler.scale(loss).backward()
165
+ scaler.step(optim)
166
+ scaler.update()
167
+
168
+ # ============================
169
+ # Saving + plots
170
+ # ============================
171
+ def save_predictions_csv(out_dir, split_name, y_true, y_pred, sequences=None):
172
+ os.makedirs(out_dir, exist_ok=True)
173
+ df = pd.DataFrame({
174
+ "y_true": y_true.astype(float),
175
+ "y_pred": y_pred.astype(float),
176
+ "residual": (y_true - y_pred).astype(float),
177
+ })
178
+ if sequences is not None:
179
+ df.insert(0, "sequence", sequences)
180
+ df.to_csv(os.path.join(out_dir, f"{split_name}_predictions.csv"), index=False)
181
+
182
+ def plot_regression_diagnostics(out_dir, y_true, y_pred):
183
+ os.makedirs(out_dir, exist_ok=True)
184
+
185
+ plt.figure()
186
+ plt.scatter(y_true, y_pred, s=8, alpha=0.5)
187
+ plt.xlabel("y_true"); plt.ylabel("y_pred")
188
+ plt.title("Predicted vs True")
189
+ plt.tight_layout()
190
+ plt.savefig(os.path.join(out_dir, "pred_vs_true.png"))
191
+ plt.close()
192
+
193
+ resid = y_true - y_pred
194
+ plt.figure()
195
+ plt.hist(resid, bins=50)
196
+ plt.xlabel("residual (y_true - y_pred)"); plt.ylabel("count")
197
+ plt.title("Residual Histogram")
198
+ plt.tight_layout()
199
+ plt.savefig(os.path.join(out_dir, "residual_hist.png"))
200
+ plt.close()
201
+
202
+ plt.figure()
203
+ plt.scatter(y_pred, resid, s=8, alpha=0.5)
204
+ plt.xlabel("y_pred"); plt.ylabel("residual")
205
+ plt.title("Residuals vs Prediction")
206
+ plt.tight_layout()
207
+ plt.savefig(os.path.join(out_dir, "residual_vs_pred.png"))
208
+ plt.close()
209
+
210
+ # ============================
211
+ # Optuna objective
212
+ # ============================
213
+ def score_from_metrics(metrics: Dict[str, float], objective: str) -> float:
214
+ if objective == "spearman":
215
+ return metrics["spearman_rho"]
216
+ if objective == "r2":
217
+ return metrics["r2"]
218
+ if objective == "neg_rmse":
219
+ return -metrics["rmse"]
220
+ raise ValueError(f"Unknown objective={objective}")
221
+
222
+ def objective_nn_reg(trial, model_name, train_ds, val_ds, device="cuda:0", objective="spearman"):
223
+ lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
224
+ wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
225
+ dropout = trial.suggest_float("dropout", 0.0, 0.5)
226
+ batch_size = trial.suggest_categorical("batch_size", [16, 32, 64])
227
+
228
+ in_dim = infer_in_dim(train_ds)
229
+
230
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
231
+ collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
232
+ val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
233
+ collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
234
+
235
+ if model_name == "mlp":
236
+ hidden = trial.suggest_categorical("hidden", [256, 512, 1024, 2048])
237
+ model = MLPRegressor(in_dim=in_dim, hidden=hidden, dropout=dropout)
238
+ elif model_name == "cnn":
239
+ c = trial.suggest_categorical("channels", [128, 256, 512])
240
+ k = trial.suggest_categorical("kernel", [3, 5, 7])
241
+ layers = trial.suggest_int("layers", 1, 4)
242
+ model = CNNRegressor(in_ch=in_dim, c=c, k=k, layers=layers, dropout=dropout)
243
+ elif model_name == "transformer":
244
+ d = trial.suggest_categorical("d_model", [128, 256, 384])
245
+ nhead = trial.suggest_categorical("nhead", [4, 8])
246
+ layers = trial.suggest_int("layers", 1, 4)
247
+ ff = trial.suggest_categorical("ff", [256, 512, 1024, 1536])
248
+ model = TransformerRegressor(in_dim=in_dim, d_model=d, nhead=nhead, layers=layers, ff=ff, dropout=dropout)
249
+ else:
250
+ raise ValueError(model_name)
251
+
252
+ model = model.to(device)
253
+
254
+ loss_name = trial.suggest_categorical("loss", ["mse", "huber"])
255
+ if loss_name == "mse":
256
+ criterion = nn.MSELoss()
257
+ else:
258
+ delta = trial.suggest_float("huber_delta", 0.5, 5.0, log=True)
259
+ criterion = nn.HuberLoss(delta=delta)
260
+
261
+ optim = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
262
+
263
+ best_score = -1e18
264
+ patience = 10
265
+ bad = 0
266
+
267
+ for epoch in range(1, 61):
268
+ train_one_epoch_reg(model, train_loader, optim, criterion, device)
269
+
270
+ y_true, y_pred = eval_preds(model, val_loader, device)
271
+ metrics = eval_regression(y_true, y_pred)
272
+ score = score_from_metrics(metrics, objective)
273
+
274
+ # log attrs
275
+ for k, v in metrics.items():
276
+ trial.set_user_attr(f"val_{k}", float(v))
277
+
278
+ trial.report(score, epoch)
279
+ if trial.should_prune():
280
+ raise optuna.TrialPruned()
281
+
282
+ if score > best_score + 1e-6:
283
+ best_score = score
284
+ bad = 0
285
+ else:
286
+ bad += 1
287
+ if bad >= patience:
288
+ break
289
+
290
+ return float(best_score)
291
+
292
+ # ============================
293
+ # Main runner
294
+ # ============================
295
+ def run_optuna_and_refit_nn_reg(dataset_path, out_dir, model_name, n_trials=80, device="cuda:0",
296
+ objective="spearman"):
297
+ os.makedirs(out_dir, exist_ok=True)
298
+
299
+ train_ds, val_ds = load_split(dataset_path)
300
+ print(f"[Data] Train: {len(train_ds)}, Val: {len(val_ds)}")
301
+
302
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
303
+ study.optimize(lambda t: objective_nn_reg(t, model_name, train_ds, val_ds, device=device, objective=objective),
304
+ n_trials=n_trials)
305
+
306
+ trials_df = study.trials_dataframe()
307
+ trials_df.to_csv(os.path.join(out_dir, "study_trials.csv"), index=False)
308
+
309
+ best = study.best_trial
310
+ best_params = dict(best.params)
311
+
312
+ # rebuild model from best params
313
+ in_dim = infer_in_dim(train_ds)
314
+ dropout = float(best_params.get("dropout", 0.1))
315
+ if model_name == "mlp":
316
+ model = MLPRegressor(in_dim=in_dim, hidden=int(best_params["hidden"]), dropout=dropout)
317
+ elif model_name == "cnn":
318
+ model = CNNRegressor(in_ch=in_dim, c=int(best_params["channels"]),
319
+ k=int(best_params["kernel"]), layers=int(best_params["layers"]),
320
+ dropout=dropout)
321
+ elif model_name == "transformer":
322
+ model = TransformerRegressor(in_dim=in_dim, d_model=int(best_params["d_model"]),
323
+ nhead=int(best_params["nhead"]), layers=int(best_params["layers"]),
324
+ ff=int(best_params["ff"]), dropout=dropout)
325
+ else:
326
+ raise ValueError(model_name)
327
+
328
+ model = model.to(device)
329
+
330
+ batch_size = int(best_params.get("batch_size", 32))
331
+ train_loader = DataLoader(train_ds, batch_size=batch_size, shuffle=True,
332
+ collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
333
+ val_loader = DataLoader(val_ds, batch_size=64, shuffle=False,
334
+ collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True)
335
+
336
+ # loss
337
+ if best_params.get("loss", "mse") == "mse":
338
+ criterion = nn.MSELoss()
339
+ else:
340
+ criterion = nn.HuberLoss(delta=float(best_params["huber_delta"]))
341
+
342
+ optim = torch.optim.AdamW(model.parameters(), lr=float(best_params["lr"]),
343
+ weight_decay=float(best_params["weight_decay"]))
344
+
345
+ # refit longer with early stopping on the SAME objective
346
+ best_score, bad, patience = -1e18, 0, 15
347
+ best_state = None
348
+
349
+ for epoch in range(1, 201):
350
+ train_one_epoch_reg(model, train_loader, optim, criterion, device)
351
+
352
+ y_true, y_pred = eval_preds(model, val_loader, device)
353
+ metrics = eval_regression(y_true, y_pred)
354
+ score = score_from_metrics(metrics, objective)
355
+
356
+ if score > best_score + 1e-6:
357
+ best_score = score
358
+ bad = 0
359
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
360
+ best_metrics = metrics
361
+ else:
362
+ bad += 1
363
+ if bad >= patience:
364
+ break
365
+
366
+ if best_state is not None:
367
+ model.load_state_dict(best_state)
368
+
369
+ # preds
370
+ y_true_tr, y_pred_tr = eval_preds(model, DataLoader(train_ds, batch_size=64, shuffle=False,
371
+ collate_fn=collate_unpooled_reg, num_workers=4, pin_memory=True), device)
372
+ y_true_va, y_pred_va = eval_preds(model, val_loader, device)
373
+
374
+ seq_train = np.asarray(train_ds["sequence"]) if "sequence" in train_ds.column_names else None
375
+ seq_val = np.asarray(val_ds["sequence"]) if "sequence" in val_ds.column_names else None
376
+ save_predictions_csv(out_dir, "train", y_true_tr, y_pred_tr, seq_train)
377
+ save_predictions_csv(out_dir, "val", y_true_va, y_pred_va, seq_val)
378
+ plot_regression_diagnostics(out_dir, y_true_va, y_pred_va)
379
+
380
+ # save model
381
+ model_path = os.path.join(out_dir, "best_model.pt")
382
+ torch.save({"state_dict": model.state_dict(), "best_params": best_params, "in_dim": in_dim}, model_path)
383
+
384
+ summary = [
385
+ "=" * 72,
386
+ f"MODEL: {model_name}",
387
+ f"OPTUNA objective: {objective} (direction=maximize)",
388
+ f"Best trial: {best.number}",
389
+ "Best val metrics:",
390
+ json.dumps({k: float(v) for k, v in best_metrics.items()}, indent=2),
391
+ f"Saved model: {model_path}",
392
+ "Best params:",
393
+ json.dumps(best_params, indent=2),
394
+ "=" * 72,
395
+ ]
396
+ with open(os.path.join(out_dir, "optimization_summary.txt"), "w") as f:
397
+ f.write("\n".join(summary))
398
+ print("\n".join(summary))
399
+
400
+
401
+ if __name__ == "__main__":
402
+ import argparse
403
+ parser = argparse.ArgumentParser()
404
+ parser.add_argument("--dataset_path", type=str, required=True)
405
+ parser.add_argument("--out_dir", type=str, required=True)
406
+ parser.add_argument("--model", type=str, choices=["mlp","cnn","transformer"], required=True)
407
+ parser.add_argument("--n_trials", type=int, default=80)
408
+ parser.add_argument("--objective", type=str, default="spearman",
409
+ choices=["spearman","neg_rmse","r2"])
410
+ parser.add_argument("--device", type=str, default="cuda:0")
411
+ args = parser.parse_args()
412
+
413
+ run_optuna_and_refit_nn_reg(
414
+ dataset_path=args.dataset_path,
415
+ out_dir=args.out_dir,
416
+ model_name=args.model,
417
+ n_trials=args.n_trials,
418
+ device=args.device,
419
+ objective=args.objective,
420
+ )
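A toy check of the metric plumbing above (illustrative numbers only; assumes `eval_regression` and `score_from_metrics` from this script are in scope):

    import numpy as np

    y_true = np.array([1.0, 2.0, 3.0, 4.0])
    y_pred = np.array([1.1, 1.9, 3.2, 3.8])
    m = eval_regression(y_true, y_pred)           # rmse / mae / r2 / spearman_rho
    print(m, score_from_metrics(m, "neg_rmse"))   # neg_rmse is negated so "maximize" still applies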
training_data_cleaned/data_split.ipynb → training_classifiers/binding_affinity/val_smiles_pooled.csv RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:981339bf1a6594e42a722a42993c238512c3ac572344f68b810f561d4b7b7757
3
- size 228787
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:5410a45a7b65def6cfb94c167b07537abd33b5aac4ecdffe162b7ce4e9bc3d19
3
+ size 36525
training_data_cleaned/nf_smiles_train.csv → training_classifiers/binding_affinity/val_smiles_unpooled.csv RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:1f08b8d9b77fef6da407a6e22765201d8eaf1cff6ae7f0da5d8da261baf64f86
3
- size 2069832
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:cdf71fbb3e7b3b8e8dbfe4ed45b32a2da0049df851f09ee32564825f626cb86c
3
+ size 37187
training_data_cleaned/smiles_data_split.ipynb → training_classifiers/binding_affinity/val_wt_pooled.csv RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:83d55a03c6934dc9ee64f7dbe76d2cf8e042be84b00f8e8bb1c92e2bc6da0c3f
3
- size 2300353
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b194e7b2b97258320323021b3ffe6143133070212a0215ade22fa91b87a3a861
3
+ size 33224
training_data_cleaned/nf_smiles_val.csv → training_classifiers/binding_affinity/val_wt_unpooled.csv RENAMED
@@ -1,3 +1,3 @@
1
  version https://git-lfs.github.com/spec/v1
2
- oid sha256:5a2cf82b6cc31686eff6a7931de34ee0975defc460e470e183b208c0513e5f3b
3
- size 55387144
 
1
  version https://git-lfs.github.com/spec/v1
2
+ oid sha256:051325790047e749fbf1daf7bf25a08178297b0c37acaf9439816d09f2b6c1e3
3
+ size 33826
training_classifiers/binding_affinity/wt_smiles_pooled/best_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:12f956a7bf04ed602c11fd275377afa73f3f0af1982dbe06c607d8ada304b01c
3
+ size 21617397
training_classifiers/binding_affinity/wt_smiles_pooled/best_params.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "lr": 0.00011987622631192274,
3
+ "weight_decay": 5.279397670067118e-05,
4
+ "dropout": 0.06773313718640918,
5
+ "hidden_dim": 256,
6
+ "n_heads": 8,
7
+ "n_layers": 3,
8
+ "cls_weight": 0.29331613012593555,
9
+ "batch_size": 16
10
+ }
training_classifiers/binding_affinity/wt_smiles_pooled/optuna_trials.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:23161560edf5ad2069302afa1d387819dcfb7010c6ff0437c61ec09a8aa0e0f0
3
+ size 40599
training_classifiers/binding_affinity/wt_smiles_unpooled/.ipynb_checkpoints/best_params-checkpoint.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "lr": 6.714904102732621e-05,
3
+ "weight_decay": 6.94348785472601e-08,
4
+ "dropout": 0.20599610484012826,
5
+ "hidden_dim": 768,
6
+ "n_heads": 4,
7
+ "n_layers": 3,
8
+ "cls_weight": 0.26109289573917854,
9
+ "batch_size": 16
10
+ }
training_classifiers/binding_affinity/wt_smiles_unpooled/best_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:3d7ae3d2190b034352a65bda1bce86aa5a96ce3daf74cf10a166f8d9e9af51f0
3
+ size 181183221
training_classifiers/binding_affinity/wt_smiles_unpooled/best_params.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "lr": 6.714904102732621e-05,
3
+ "weight_decay": 6.94348785472601e-08,
4
+ "dropout": 0.20599610484012826,
5
+ "hidden_dim": 768,
6
+ "n_heads": 4,
7
+ "n_layers": 3,
8
+ "cls_weight": 0.26109289573917854,
9
+ "batch_size": 16
10
+ }
training_classifiers/binding_affinity/wt_smiles_unpooled/optuna_trials.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:11838c42881182dd76b055f06e89b4423659eacd729d695b8a8f4c0a10165da0
3
+ size 40533
training_classifiers/binding_affinity/wt_wt_pooled/.ipynb_checkpoints/optuna_trials-checkpoint.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b685b92714882d618b42b582000574d83c3be2fbecbec5e0de6b5476948b96c5
3
+ size 40700
training_classifiers/binding_affinity/wt_wt_pooled/best_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:636de30f4388efd55e8e625c4f2c71a7629982197ef1e53fecf2a4f640df1ae0
3
+ size 182756085
training_classifiers/binding_affinity/wt_wt_pooled/best_params.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "lr": 0.0001730381005812531,
3
+ "weight_decay": 8.736709570411299e-05,
4
+ "dropout": 0.17533811687195416,
5
+ "hidden_dim": 768,
6
+ "n_heads": 4,
7
+ "n_layers": 3,
8
+ "cls_weight": 0.1278591739909013,
9
+ "batch_size": 16
10
+ }
training_classifiers/binding_affinity/wt_wt_pooled/optuna_trials.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b685b92714882d618b42b582000574d83c3be2fbecbec5e0de6b5476948b96c5
3
+ size 40700
training_classifiers/binding_affinity/wt_wt_unpooled/best_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ce8da2406036535794909ffde1f1941843096b7d3c71ba772f9170ab123877f2
3
+ size 69333557
training_classifiers/binding_affinity/wt_wt_unpooled/best_params.json ADDED
@@ -0,0 +1,10 @@
1
+ {
2
+ "lr": 0.000657577559506255,
3
+ "weight_decay": 3.3209159985473103e-07,
4
+ "dropout": 0.16430662769055482,
5
+ "hidden_dim": 768,
6
+ "n_heads": 8,
7
+ "n_layers": 1,
8
+ "cls_weight": 0.7037398702018655,
9
+ "batch_size": 16
10
+ }
training_classifiers/binding_affinity/wt_wt_unpooled/optuna_trials.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:b330d37733e684ff780b143f17fd26c498f615dfbab8c5b3df08eae7eb019139
3
+ size 40587
training_classifiers/binding_affinity_iptm.py ADDED
@@ -0,0 +1,132 @@
1
+ #!/usr/bin/env python3
2
+ """
3
+ binding_affinity_iptm.py
4
+
5
+ Writes:
6
+ - out_dir/wt_iptm_affinity_all.csv
7
+ - out_dir/smiles_iptm_affinity_all.csv
8
+
9
+ Also prints:
10
+ - N
11
+ - Spearman rho (affinity vs iptm)
12
+ - Pearson r (affinity vs iptm)
13
+ """
14
+
15
+ from pathlib import Path
16
+ import numpy as np
17
+ import pandas as pd
18
+
19
+
20
+ def corr_stats(df: pd.DataFrame, x: str, y: str):
21
+ # NaNs should already be dropped upstream; coerce and mask again here to be safe
22
+ xx = pd.to_numeric(df[x], errors="coerce")
23
+ yy = pd.to_numeric(df[y], errors="coerce")
24
+ m = xx.notna() & yy.notna()
25
+ xx = xx[m]
26
+ yy = yy[m]
27
+ n = int(m.sum())
28
+
29
+ # Pearson r
30
+ pearson_r = float(xx.corr(yy, method="pearson")) if n > 1 else float("nan")
31
+ # Spearman rho
32
+ spearman_rho = float(xx.corr(yy, method="spearman")) if n > 1 else float("nan")
33
+
34
+ return {"n": n, "pearson_r": pearson_r, "spearman_rho": spearman_rho}
35
+
36
+
37
+ def clean_one(
38
+ in_csv: Path,
39
+ out_csv: Path,
40
+ iptm_col: str,
41
+ affinity_col: str = "affinity",
42
+ keep_cols=(),
43
+ ):
44
+ df = pd.read_csv(in_csv)
45
+
46
+ # affinity + iptm must exist
47
+ need = [affinity_col, iptm_col]
48
+ missing = [c for c in need if c not in df.columns]
49
+ if missing:
50
+ raise ValueError(f"{in_csv} missing columns: {missing}. Found: {list(df.columns)}")
51
+
52
+ # coerce numeric
53
+ df[affinity_col] = pd.to_numeric(df[affinity_col], errors="coerce")
54
+ df[iptm_col] = pd.to_numeric(df[iptm_col], errors="coerce")
55
+
56
+ # drop NaNs in either
57
+ df = df.dropna(subset=[affinity_col, iptm_col]).reset_index(drop=True)
58
+
59
+ # output cols (standardize names)
60
+ out = pd.DataFrame({
61
+ "affinity": df[affinity_col].astype(float),
62
+ "iptm": df[iptm_col].astype(float),
63
+ })
64
+
65
+ # keep split if present (handy for coloring later, but not used for corr)
66
+ if "split" in df.columns:
67
+ out.insert(0, "split", df["split"].astype(str))
68
+
69
+ # optional extras for labeling/debug
70
+ for c in keep_cols:
71
+ if c in df.columns:
72
+ out[c] = df[c]
73
+
74
+ out_csv.parent.mkdir(parents=True, exist_ok=True)
75
+ out.to_csv(out_csv, index=False)
76
+
77
+ stats = corr_stats(out, "iptm", "affinity")
78
+ print(f"[write] {out_csv}")
79
+ print(f" N={stats['n']} | Pearson r={stats['pearson_r']:.4f} | Spearman rho={stats['spearman_rho']:.4f}")
80
+
81
+ # also save stats json next to csv
82
+ stats_path = out_csv.with_suffix(".stats.json")
83
+ with open(stats_path, "w") as f:
84
+ import json
85
+ json.dump(
86
+ {
87
+ "input_csv": str(in_csv),
88
+ "output_csv": str(out_csv),
89
+ "iptm_col": iptm_col,
90
+ "affinity_col": affinity_col,
91
+ **stats,
92
+ },
93
+ f,
94
+ indent=2,
95
+ )
96
+
97
+
98
+ def main():
99
+ import argparse
100
+ ap = argparse.ArgumentParser()
101
+ ap.add_argument("--wt_meta_csv", type=str, required=True)
102
+ ap.add_argument("--smiles_meta_csv", type=str, required=True)
103
+ ap.add_argument("--out_dir", type=str, required=True)
104
+
105
+ ap.add_argument("--wt_iptm_col", type=str, default="wt_iptm_score")
106
+ ap.add_argument("--smiles_iptm_col", type=str, default="smiles_iptm_score")
107
+ ap.add_argument("--affinity_col", type=str, default="affinity")
108
+ args = ap.parse_args()
109
+
110
+ out_dir = Path(args.out_dir)
111
+
112
+ clean_one(
113
+ Path(args.wt_meta_csv),
114
+ out_dir / "wt_iptm_affinity_all.csv",
115
+ iptm_col=args.wt_iptm_col,
116
+ affinity_col=args.affinity_col,
117
+ keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES"),
118
+ )
119
+
120
+ clean_one(
121
+ Path(args.smiles_meta_csv),
122
+ out_dir / "smiles_iptm_affinity_all.csv",
123
+ iptm_col=args.smiles_iptm_col,
124
+ affinity_col=args.affinity_col,
125
+ keep_cols=("seq1", "seq2", "Fasta2SMILES", "REACT_SMILES", "smiles_sequence"),
126
+ )
127
+
128
+ print(f"\n[DONE] CSVs + stats JSONs in: {out_dir}")
129
+
130
+
131
+ if __name__ == "__main__":
132
+ main()
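Illustrative use of `corr_stats` above on a synthetic frame (numbers are made up, not project data):

    import pandas as pd

    toy = pd.DataFrame({"iptm": [0.42, 0.61, 0.55, 0.80],
                        "affinity": [6.1, 7.9, 7.2, 9.3]})
    print(corr_stats(toy, "iptm", "affinity"))   # -> {'n': 4, 'pearson_r': ..., 'spearman_rho': ...}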
training_classifiers/binding_affinity_split.py ADDED
@@ -0,0 +1,847 @@
1
+ #!/usr/bin/env python3
2
+ import os
3
+ import math
4
+ from pathlib import Path
5
+ import sys
6
+ from contextlib import contextmanager
7
+
8
+ import numpy as np
9
+ import pandas as pd
10
+ import torch
11
+
12
+ # tqdm is optional; we’ll disable it by default in notebooks
13
+ from tqdm import tqdm
14
+
15
+ sys.path.append("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight")
16
+ from tokenizer.my_tokenizers import SMILES_SPE_Tokenizer
17
+
18
+ from datasets import Dataset, DatasetDict, Features, Value, Sequence as HFSequence
19
+ from transformers import AutoTokenizer, EsmModel, AutoModelForMaskedLM
20
+
21
+ # -------------------------
22
+ # Config
23
+ # -------------------------
24
+ CSV_PATH = Path("/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/c-binding_with_openfold_scores.csv")
25
+
26
+ OUT_ROOT = Path(
27
+ "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_data_cleaned/binding_affinity"
28
+ )
29
+
30
+ # WT (seq) embedding model
31
+ WT_MODEL_NAME = "facebook/esm2_t33_650M_UR50D"
32
+ WT_MAX_LEN = 1022
33
+ WT_BATCH = 32
34
+
35
+ # SMILES embedding model + tokenizer
36
+ SMI_MODEL_NAME = "aaronfeller/PeptideCLM-23M-all"
37
+ TOKENIZER_VOCAB = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_vocab.txt"
38
+ TOKENIZER_SPLITS = "/vast/projects/pranam/lab/yz927/projects/Classifier_Weight/tokenizer/new_splits.txt"
39
+ SMI_MAX_LEN = 768
40
+ SMI_BATCH = 128
41
+
42
+ # Split config
43
+ TRAIN_FRAC = 0.80
44
+ RANDOM_SEED = 1986
45
+ AFFINITY_Q_BINS = 30
46
+
47
+ # Columns expected in CSV
48
+ COL_SEQ1 = "seq1"
49
+ COL_SEQ2 = "seq2"
50
+ COL_AFF = "affinity"
51
+ COL_F2S = "Fasta2SMILES"
52
+ COL_REACT = "REACT_SMILES"
53
+ COL_WT_IPTM = "wt_iptm_score"
54
+ COL_SMI_IPTM = "smiles_iptm_score"
55
+
56
+ # Device
57
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
58
+
59
+ # -------------------------
60
+ # Quiet / notebook-safe output controls
61
+ # -------------------------
62
+ QUIET = True # suppress most prints
63
+ USE_TQDM = False # disable tqdm bars (recommended in Jupyter to avoid crashing)
64
+ LOG_FILE = None # optionally: OUT_ROOT / "build.log"
65
+
66
+ def log(msg: str):
67
+ if LOG_FILE is not None:
68
+ Path(LOG_FILE).parent.mkdir(parents=True, exist_ok=True)
69
+ with open(LOG_FILE, "a") as f:
70
+ f.write(msg.rstrip() + "\n")
71
+ if not QUIET:
72
+ print(msg)
73
+
74
+ def pbar(it, **kwargs):
75
+ return tqdm(it, **kwargs) if USE_TQDM else it
76
+
77
+ @contextmanager
78
+ def section(title: str):
79
+ log(f"\n=== {title} ===")
80
+ yield
81
+ log(f"=== done: {title} ===")
82
+
83
+
84
+ # -------------------------
85
+ # Helpers
86
+ # -------------------------
87
+ def has_uaa(seq: str) -> bool:
88
+ return "X" in str(seq).upper()
89
+
90
+ def affinity_to_class(a: float) -> str:
91
+ # High: >= 9 ; Moderate: [7, 9) ; Low: < 7
92
+ if a >= 9.0:
93
+ return "High"
94
+ elif a >= 7.0:
95
+ return "Moderate"
96
+ else:
97
+ return "Low"
98
+
99
+ def make_distribution_matched_split(df: pd.DataFrame) -> pd.DataFrame:
100
+ df = df.copy()
101
+
102
+ df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
103
+ df = df.dropna(subset=[COL_AFF]).reset_index(drop=True)
104
+
105
+ df["affinity_class"] = df[COL_AFF].apply(affinity_to_class)
106
+
107
+ try:
108
+ df["aff_bin"] = pd.qcut(df[COL_AFF], q=AFFINITY_Q_BINS, duplicates="drop")
109
+ strat_col = "aff_bin"
110
+ except Exception:
111
+ df["aff_bin"] = df["affinity_class"]
112
+ strat_col = "aff_bin"
113
+
114
+ rng = np.random.RandomState(RANDOM_SEED)
115
+
116
+ df["split"] = None
117
+ for _, g in df.groupby(strat_col, observed=True):
118
+ idx = g.index.to_numpy()
119
+ rng.shuffle(idx)
120
+ n_train = int(math.floor(len(idx) * TRAIN_FRAC))
121
+ df.loc[idx[:n_train], "split"] = "train"
122
+ df.loc[idx[n_train:], "split"] = "val"
123
+
124
+ df["split"] = df["split"].fillna("train")
125
+ return df
126
+
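# Illustration (not part of the commit): a self-contained toy run of the quantile-stratified
# split idea used by make_distribution_matched_split above; relies only on names already
# defined in this script (np, pd, math, COL_AFF).
def _demo_distribution_matched_split(n_rows: int = 200, n_bins: int = 10, train_frac: float = 0.8):
    rng = np.random.RandomState(0)
    toy = pd.DataFrame({COL_AFF: rng.uniform(4.0, 11.0, size=n_rows)})   # synthetic affinities
    toy["aff_bin"] = pd.qcut(toy[COL_AFF], q=n_bins, duplicates="drop")
    toy["split"] = None
    for _, g in toy.groupby("aff_bin", observed=True):
        idx = g.index.to_numpy()
        rng.shuffle(idx)
        n_train = int(math.floor(len(idx) * train_frac))
        toy.loc[idx[:n_train], "split"] = "train"
        toy.loc[idx[n_train:], "split"] = "val"
    return toy["split"].value_counts(normalize=True)  # overall ~80/20, matched bin by bin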
127
+ def _summ(x):
128
+ x = np.asarray(x, dtype=float)
129
+ x = x[~np.isnan(x)]
130
+ if len(x) == 0:
131
+ return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
132
+ return {
133
+ "n": int(len(x)),
134
+ "mean": float(np.mean(x)),
135
+ "std": float(np.std(x)),
136
+ "p50": float(np.quantile(x, 0.50)),
137
+ "p95": float(np.quantile(x, 0.95)),
138
+ }
139
+
140
+ def _len_stats(seqs):
141
+ lens = np.asarray([len(str(s)) for s in seqs], dtype=float)
142
+ if len(lens) == 0:
143
+ return {"n": 0, "mean": np.nan, "std": np.nan, "p50": np.nan, "p95": np.nan}
144
+ return {
145
+ "n": int(len(lens)),
146
+ "mean": float(lens.mean()),
147
+ "std": float(lens.std()),
148
+ "p50": float(np.quantile(lens, 0.50)),
149
+ "p95": float(np.quantile(lens, 0.95)),
150
+ }
151
+
152
+ def verify_split_before_embedding(
153
+ df2: pd.DataFrame,
154
+ affinity_col: str,
155
+ split_col: str,
156
+ seq_col: str,
157
+ iptm_col: str,
158
+ aff_class_col: str = "affinity_class",
159
+ aff_bins: int = 30,
160
+ save_report_prefix: str | None = None,
161
+ verbose: bool = False,
162
+ ):
163
+ """
164
+ Notebook-safe: by default prints only ONE line via `log()`.
165
+ Optionally writes CSV reports (stats + class proportions).
166
+ """
167
+ df2 = df2.copy()
168
+ df2[affinity_col] = pd.to_numeric(df2[affinity_col], errors="coerce")
169
+ df2[iptm_col] = pd.to_numeric(df2[iptm_col], errors="coerce")
170
+
171
+ assert split_col in df2.columns, f"Missing split col: {split_col}"
172
+ assert set(df2[split_col].dropna().unique()).issubset({"train", "val"}), f"Unexpected split values: {df2[split_col].unique()}"
173
+ assert df2[affinity_col].notna().any(), "No valid affinity values after coercion."
174
+
175
+ try:
176
+ df2["_aff_bin_dbg"] = pd.qcut(df2[affinity_col], q=aff_bins, duplicates="drop")
177
+ except Exception:
178
+ df2["_aff_bin_dbg"] = df2[aff_class_col].astype(str)
179
+
180
+ tr = df2[df2[split_col] == "train"].reset_index(drop=True)
181
+ va = df2[df2[split_col] == "val"].reset_index(drop=True)
182
+
183
+ tr_aff = _summ(tr[affinity_col].to_numpy())
184
+ va_aff = _summ(va[affinity_col].to_numpy())
185
+ tr_len = _len_stats(tr[seq_col].tolist())
186
+ va_len = _len_stats(va[seq_col].tolist())
187
+
188
+ # bin drift
189
+ bin_ct = (
190
+ df2.groupby([split_col, "_aff_bin_dbg"])
191
+ .size()
192
+ .groupby(level=0)
193
+ .apply(lambda s: s / s.sum())
194
+ )
195
+ tr_bins = bin_ct.loc["train"]
196
+ va_bins = bin_ct.loc["val"]
197
+ all_bins = tr_bins.index.union(va_bins.index)
198
+ tr_bins = tr_bins.reindex(all_bins, fill_value=0.0)
199
+ va_bins = va_bins.reindex(all_bins, fill_value=0.0)
200
+ max_bin_diff = float(np.max(np.abs(tr_bins.values - va_bins.values)))
201
+
202
+ msg = (
203
+ f"[split-check] rows={len(df2)} train={len(tr)} val={len(va)} | "
204
+ f"aff(mean±std) train={tr_aff['mean']:.3f}±{tr_aff['std']:.3f} val={va_aff['mean']:.3f}±{va_aff['std']:.3f} | "
205
+ f"len(p50/p95) train={tr_len['p50']:.1f}/{tr_len['p95']:.1f} val={va_len['p50']:.1f}/{va_len['p95']:.1f} | "
206
+ f"max_bin_diff={max_bin_diff:.4f}"
207
+ )
208
+ log(msg)
209
+
210
+ if verbose and (not QUIET):
211
+ class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
212
+ class_prop = class_ct.div(class_ct.sum(axis=1), axis=0)
213
+ print("\n[verbose] affinity_class counts:\n", class_ct)
214
+ print("\n[verbose] affinity_class proportions:\n", class_prop.round(4))
215
+
216
+ if save_report_prefix is not None:
217
+ out = Path(save_report_prefix)
218
+ out.parent.mkdir(parents=True, exist_ok=True)
219
+
220
+ stats_df = pd.DataFrame([
221
+ {"split": "train", **{f"aff_{k}": v for k, v in tr_aff.items()}, **{f"len_{k}": v for k, v in tr_len.items()}},
222
+ {"split": "val", **{f"aff_{k}": v for k, v in va_aff.items()}, **{f"len_{k}": v for k, v in va_len.items()}},
223
+ ])
224
+ class_ct = df2.groupby([split_col, aff_class_col]).size().unstack(fill_value=0)
225
+ class_prop = class_ct.div(class_ct.sum(axis=1), axis=0).reset_index()
226
+
227
+ stats_df.to_csv(out.with_suffix(".stats.csv"), index=False)
228
+ class_prop.to_csv(out.with_suffix(".class_prop.csv"), index=False)
229
+
230
+
231
+ # -------------------------
232
+ # WT pooled (ESM2)
233
+ # -------------------------
234
+ @torch.no_grad()
235
+ def wt_pooled_embeddings(seqs, tokenizer, model, batch_size=32, max_length=1022):
236
+ embs = []
237
+ for i in pbar(range(0, len(seqs), batch_size)):
238
+ batch = seqs[i:i + batch_size]
239
+ inputs = tokenizer(
240
+ batch,
241
+ padding=True,
242
+ truncation=True,
243
+ max_length=max_length,
244
+ return_tensors="pt",
245
+ )
246
+ inputs = {k: v.to(DEVICE) for k, v in inputs.items()}
247
+ out = model(**inputs)
248
+ h = out.last_hidden_state # (B, L, H)
249
+
250
+ attn = inputs["attention_mask"].unsqueeze(-1) # (B, L, 1)
251
+ summed = (h * attn).sum(dim=1) # (B, H)
252
+ denom = attn.sum(dim=1).clamp(min=1e-9) # (B, 1)
253
+ pooled = (summed / denom).detach().cpu().numpy()
254
+ embs.append(pooled)
255
+
256
+ return np.vstack(embs)
257
+
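# Shape note (illustration, not part of the commit): for B sequences padded to length L with
# hidden size H, last_hidden_state is (B, L, H); weighting by the (B, L, 1) attention mask and
# dividing the sum by the per-sequence token count yields the (B, H) mean-pooled embeddings.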
258
+
259
+ # -------------------------
260
+ # WT unpooled (ESM2)
261
+ # -------------------------
262
+ @torch.no_grad()
263
+ def wt_unpooled_one(seq, tokenizer, model, cls_id, eos_id, max_length=1022):
264
+ tok = tokenizer(seq, padding=False, truncation=True, max_length=max_length, return_tensors="pt")
265
+ tok = {k: v.to(DEVICE) for k, v in tok.items()}
266
+ out = model(**tok)
267
+ h = out.last_hidden_state[0] # (L, H)
268
+ attn = tok["attention_mask"][0].bool() # (L,)
269
+ ids = tok["input_ids"][0]
270
+
271
+ keep = attn.clone()
272
+ if cls_id is not None:
273
+ keep &= (ids != cls_id)
274
+ if eos_id is not None:
275
+ keep &= (ids != eos_id)
276
+
277
+ return h[keep].detach().cpu().to(torch.float16).numpy()
278
+
279
+ def build_wt_unpooled_dataset(df_split: pd.DataFrame, out_dir: Path, tokenizer, model):
280
+ """
281
+ Expects df_split to have:
282
+ - target_sequence (seq1)
283
+ - sequence (binder seq2; WT binder)
284
+ - label, affinity_class, COL_AFF, COL_WT_IPTM
285
+ Saves a dataset where each row contains BOTH:
286
+ - target_embedding (Lt,H), target_attention_mask, target_length
287
+ - binder_embedding (Lb,H), binder_attention_mask, binder_length
288
+ """
289
+ cls_id = tokenizer.cls_token_id
290
+ eos_id = tokenizer.eos_token_id
291
+ H = model.config.hidden_size
292
+
293
+ features = Features({
294
+ "target_sequence": Value("string"),
295
+ "sequence": Value("string"),
296
+ "label": Value("float32"),
297
+ "affinity": Value("float32"),
298
+ "affinity_class": Value("string"),
299
+
300
+ "target_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
301
+ "target_attention_mask": HFSequence(Value("int8")),
302
+ "target_length": Value("int64"),
303
+
304
+ "binder_embedding": HFSequence(HFSequence(Value("float16"), length=H)),
305
+ "binder_attention_mask": HFSequence(Value("int8")),
306
+ "binder_length": Value("int64"),
307
+
308
+ COL_WT_IPTM: Value("float32"),
309
+ COL_AFF: Value("float32"),
310
+ })
311
+
312
+ def gen_rows(df: pd.DataFrame):
313
+ for r in pbar(df.itertuples(index=False), total=len(df)):
314
+ tgt = str(getattr(r, "target_sequence")).strip()
315
+ bnd = str(getattr(r, "sequence")).strip()
316
+
317
+ y = float(getattr(r, "label"))
318
+ aff = float(getattr(r, COL_AFF))
319
+ acls = str(getattr(r, "affinity_class"))
320
+
321
+ iptm = getattr(r, COL_WT_IPTM)
322
+ iptm = float(iptm) if pd.notna(iptm) else np.nan
323
+
324
+ # token embeddings for target + binder (both ESM)
325
+ t_emb = wt_unpooled_one(tgt, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lt,H)
326
+ b_emb = wt_unpooled_one(bnd, tokenizer, model, cls_id, eos_id, max_length=WT_MAX_LEN) # (Lb,H)
327
+
328
+ t_list = t_emb.tolist()
329
+ b_list = b_emb.tolist()
330
+ Lt = len(t_list)
331
+ Lb = len(b_list)
332
+
333
+ yield {
334
+ "target_sequence": tgt,
335
+ "sequence": bnd,
336
+ "label": np.float32(y),
337
+ "affinity": np.float32(aff),
338
+ "affinity_class": acls,
339
+
340
+ "target_embedding": t_list,
341
+ "target_attention_mask": [1] * Lt,
342
+ "target_length": int(Lt),
343
+
344
+ "binder_embedding": b_list,
345
+ "binder_attention_mask": [1] * Lb,
346
+ "binder_length": int(Lb),
347
+
348
+ COL_WT_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
349
+ COL_AFF: np.float32(aff),
350
+ }
351
+
352
+ out_dir.mkdir(parents=True, exist_ok=True)
353
+ ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
354
+ ds.save_to_disk(str(out_dir), max_shard_size="1GB")
355
+ return ds
356
+
357
+ def build_smiles_unpooled_paired_dataset(df_split: pd.DataFrame, out_dir: Path, wt_tokenizer, wt_model_unpooled,
358
+ smi_tok, smi_roformer):
359
+ """
360
+ df_split must have:
361
+ - target_sequence (seq1)
362
+ - sequence (binder smiles string)
363
+ - label, affinity_class, COL_AFF, COL_SMI_IPTM
364
+ Saves rows with:
365
+ target_embedding (Lt,Ht) from ESM
366
+ binder_embedding (Lb,Hb) from PeptideCLM
367
+ """
368
+ cls_id = wt_tokenizer.cls_token_id
369
+ eos_id = wt_tokenizer.eos_token_id
370
+ Ht = wt_model_unpooled.config.hidden_size
371
+
372
+ # Hb (binder hidden size) could be inferred from a single forward pass, but the
+ # simplest option is to read it from the model config when available.
374
+ Hb = getattr(smi_roformer.config, "hidden_size", None)
375
+ if Hb is None:
376
+ Hb = getattr(smi_roformer.config, "dim", None)
377
+ if Hb is None:
378
+ raise ValueError("Cannot infer Hb from smi_roformer config; print(smi_roformer.config) and set Hb manually.")
379
+
380
+ features = Features({
381
+ "target_sequence": Value("string"),
382
+ "sequence": Value("string"),
383
+ "label": Value("float32"),
384
+ "affinity": Value("float32"),
385
+ "affinity_class": Value("string"),
386
+
387
+ "target_embedding": HFSequence(HFSequence(Value("float16"), length=Ht)),
388
+ "target_attention_mask": HFSequence(Value("int8")),
389
+ "target_length": Value("int64"),
390
+
391
+ "binder_embedding": HFSequence(HFSequence(Value("float16"), length=Hb)),
392
+ "binder_attention_mask": HFSequence(Value("int8")),
393
+ "binder_length": Value("int64"),
394
+
395
+ COL_SMI_IPTM: Value("float32"),
396
+ COL_AFF: Value("float32"),
397
+ })
398
+
399
+ def gen_rows(df: pd.DataFrame):
400
+ for r in pbar(df.itertuples(index=False), total=len(df)):
401
+ tgt = str(getattr(r, "target_sequence")).strip()
402
+ bnd = str(getattr(r, "sequence")).strip()
403
+
404
+ y = float(getattr(r, "label"))
405
+ aff = float(getattr(r, COL_AFF))
406
+ acls = str(getattr(r, "affinity_class"))
407
+
408
+ iptm = getattr(r, COL_SMI_IPTM)
409
+ iptm = float(iptm) if pd.notna(iptm) else np.nan
410
+
411
+ # target token embeddings (ESM)
412
+ t_emb = wt_unpooled_one(tgt, wt_tokenizer, wt_model_unpooled, cls_id, eos_id, max_length=WT_MAX_LEN)
413
+ t_list = t_emb.tolist()
414
+ Lt = len(t_list)
415
+
416
+ # binder token embeddings (PeptideCLM) — single-item batch
417
+ _, tok_list, mask_list, lengths = smiles_embed_batch_return_both(
418
+ [bnd], smi_tok, smi_roformer, max_length=SMI_MAX_LEN
419
+ )
420
+ b_emb = tok_list[0] # np.float16 (Lb, Hb)
421
+ b_list = b_emb.tolist()
422
+ Lb = int(lengths[0])
423
+ b_mask = mask_list[0].astype(np.int8).tolist()
424
+
425
+ yield {
426
+ "target_sequence": tgt,
427
+ "sequence": bnd,
428
+ "label": np.float32(y),
429
+ "affinity": np.float32(aff),
430
+ "affinity_class": acls,
431
+
432
+ "target_embedding": t_list,
433
+ "target_attention_mask": [1] * Lt,
434
+ "target_length": int(Lt),
435
+
436
+ "binder_embedding": b_list,
437
+ "binder_attention_mask": [int(x) for x in b_mask],
438
+ "binder_length": int(Lb),
439
+
440
+ COL_SMI_IPTM: np.float32(iptm) if not np.isnan(iptm) else np.float32(np.nan),
441
+ COL_AFF: np.float32(aff),
442
+ }
443
+
444
+ out_dir.mkdir(parents=True, exist_ok=True)
445
+ ds = Dataset.from_generator(lambda: gen_rows(df_split), features=features)
446
+ ds.save_to_disk(str(out_dir), max_shard_size="1GB")
447
+ return ds
448
+
449
+
450
+ # -------------------------
451
+ # SMILES pooled + unpooled (PeptideCLM)
452
+ # -------------------------
453
+ def get_special_ids(tokenizer_obj):
454
+ cand = [
455
+ getattr(tokenizer_obj, "pad_token_id", None),
456
+ getattr(tokenizer_obj, "cls_token_id", None),
457
+ getattr(tokenizer_obj, "sep_token_id", None),
458
+ getattr(tokenizer_obj, "bos_token_id", None),
459
+ getattr(tokenizer_obj, "eos_token_id", None),
460
+ getattr(tokenizer_obj, "mask_token_id", None),
461
+ ]
462
+ return sorted({x for x in cand if x is not None})
463
+
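# Toy-tensor sketch of the masking done in smiles_embed_batch_return_both below:
# special-token positions are excluded before mean pooling. The IDs used here
# (pad=0, cls=1, eos=2) are placeholders, not the real PeptideCLM vocabulary.
import torch

toy_hidden = torch.randn(2, 5, 8)                          # (B, L, H)
toy_ids = torch.tensor([[1, 10, 11, 2, 0],
                        [1, 12, 13, 14, 2]])
toy_mask = torch.tensor([[1, 1, 1, 1, 0],
                         [1, 1, 1, 1, 1]])
toy_special = torch.tensor([0, 1, 2])

toy_valid = toy_mask.bool() & ~torch.isin(toy_ids, toy_special)
toy_valid_f = toy_valid.unsqueeze(-1).float()
toy_pooled = (toy_hidden * toy_valid_f).sum(dim=1) / toy_valid_f.sum(dim=1).clamp(min=1e-9)
print(toy_valid.sum(dim=1))                                # tensor([2, 3]) real tokens per row
print(toy_pooled.shape)                                    # torch.Size([2, 8])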
464
+ @torch.no_grad()
465
+ def smiles_embed_batch_return_both(batch_sequences, tokenizer_obj, model_roformer, max_length):
466
+ tok = tokenizer_obj(
467
+ batch_sequences,
468
+ return_tensors="pt",
469
+ padding=True,
470
+ truncation=True,
471
+ max_length=max_length,
472
+ )
473
+ input_ids = tok["input_ids"].to(DEVICE)
474
+ attention_mask = tok["attention_mask"].to(DEVICE)
475
+
476
+ outputs = model_roformer(input_ids=input_ids, attention_mask=attention_mask)
477
+ last_hidden = outputs.last_hidden_state # (B, L, H)
478
+
479
+ special_ids = get_special_ids(tokenizer_obj)
480
+ valid = attention_mask.bool()
481
+ if len(special_ids) > 0:
482
+ sid = torch.tensor(special_ids, device=DEVICE, dtype=torch.long)
483
+ if hasattr(torch, "isin"):
484
+ valid = valid & (~torch.isin(input_ids, sid))
485
+ else:
486
+ m = torch.zeros_like(valid)
487
+ for s in special_ids:
488
+ m |= (input_ids == s)
489
+ valid = valid & (~m)
490
+
491
+ valid_f = valid.unsqueeze(-1).float()
492
+ summed = torch.sum(last_hidden * valid_f, dim=1)
493
+ denom = torch.clamp(valid_f.sum(dim=1), min=1e-9)
494
+ pooled = (summed / denom).detach().cpu().numpy()
495
+
496
+ token_emb_list, mask_list, lengths = [], [], []
497
+ for b in range(last_hidden.shape[0]):
498
+ emb = last_hidden[b, valid[b]] # (Li, H)
499
+ token_emb_list.append(emb.detach().cpu().to(torch.float16).numpy())
500
+ li = emb.shape[0]
501
+ lengths.append(int(li))
502
+ mask_list.append(np.ones((li,), dtype=np.int8))
503
+
504
+ return pooled, token_emb_list, mask_list, lengths
505
+
506
+ def smiles_generate_embeddings_batched_both(seqs, tokenizer_obj, model_roformer, batch_size, max_length):
507
+ pooled_all = []
508
+ token_emb_all = []
509
+ mask_all = []
510
+ lengths_all = []
511
+
512
+ for i in pbar(range(0, len(seqs), batch_size)):
513
+ batch = seqs[i:i + batch_size]
514
+ pooled, tok_list, m_list, lens = smiles_embed_batch_return_both(
515
+ batch, tokenizer_obj, model_roformer, max_length
516
+ )
517
+ pooled_all.append(pooled)
518
+ token_emb_all.extend(tok_list)
519
+ mask_all.extend(m_list)
520
+ lengths_all.extend(lens)
521
+
522
+ return np.vstack(pooled_all), token_emb_all, mask_all, lengths_all
523
+
524
+ # -------------------------
525
+ # Target embedding cache helper (avoids extra ESM runs by computing target pooled
+ # embeddings once from the WT view; note that main() below embeds targets per branch instead)
527
+ # -------------------------
528
+ def build_target_cache_from_wt_view(wt_view_train: pd.DataFrame, wt_view_val: pd.DataFrame):
529
+ wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
530
+ wt_model = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
531
+
532
+ # compute target pooled embeddings once
533
+ tgt_wt_train = wt_view_train["target_sequence"].astype(str).tolist()
534
+ tgt_wt_val = wt_view_val["target_sequence"].astype(str).tolist()
535
+
536
+ wt_train_tgt_emb = wt_pooled_embeddings(
537
+ tgt_wt_train, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
538
+ )
539
+ wt_val_tgt_emb = wt_pooled_embeddings(
540
+ tgt_wt_val, wt_tok, wt_model, batch_size=WT_BATCH, max_length=WT_MAX_LEN
541
+ )
542
+
543
+ # build dict: target_sequence -> embedding (float32 array)
544
+ # if duplicates exist, last wins; you can add checks if needed
545
+ train_map = {s: e for s, e in zip(tgt_wt_train, wt_train_tgt_emb)}
546
+ val_map = {s: e for s, e in zip(tgt_wt_val, wt_val_tgt_emb)}
547
+ return wt_tok, wt_model, wt_train_tgt_emb, wt_val_tgt_emb, train_map, val_map
548
+ # -------------------------
549
+ # Main
550
+ # -------------------------
551
+ def main():
552
+ log(f"[INFO] DEVICE: {DEVICE}")
553
+ OUT_ROOT.mkdir(parents=True, exist_ok=True)
554
+
555
+ # 1) Load
556
+ with section("load csv + dedup"):
557
+ df = pd.read_csv(CSV_PATH)
558
+ for c in [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]:
559
+ if c in df.columns:
560
+ df[c] = df[c].apply(lambda x: x.strip() if isinstance(x, str) else x)
561
+
562
+ # Dedup on the full identity tuple you want
563
+ DEDUP_COLS = [COL_SEQ1, COL_SEQ2, COL_F2S, COL_REACT]
564
+ df = df.drop_duplicates(subset=DEDUP_COLS).reset_index(drop=True)
565
+
566
+ print("Rows after dedup on", DEDUP_COLS, ":", len(df))
567
+
568
+ need = [COL_SEQ1, COL_SEQ2, COL_AFF, COL_F2S, COL_REACT, COL_WT_IPTM, COL_SMI_IPTM]
569
+ missing = [c for c in need if c not in df.columns]
570
+ if missing:
571
+ raise ValueError(f"Missing required columns: {missing}")
572
+
573
+ # numeric affinity for both branches
574
+ df[COL_AFF] = pd.to_numeric(df[COL_AFF], errors="coerce")
575
+
576
+ # 2) Build WT subset + SMILES subset separately (NO global dropping)
577
+ with section("prepare wt/smiles subsets"):
578
+ # WT: requires a canonical peptide sequence (no X) + affinity
579
+ df_wt = df.copy()
580
+ df_wt["wt_sequence"] = df_wt[COL_SEQ2].astype(str).str.strip()
581
+ df_wt = df_wt.dropna(subset=[COL_AFF]).reset_index(drop=True)
582
+ df_wt = df_wt[df_wt["wt_sequence"].notna() & (df_wt["wt_sequence"] != "")]
583
+ df_wt = df_wt[~df_wt["wt_sequence"].str.contains("X", case=False, na=False)].reset_index(drop=True)
584
+
585
+ # SMILES: requires affinity + a usable picked SMILES (UAA->REACT, else->Fasta2SMILES)
586
+ df_smi = df.copy()
587
+ df_smi = df_smi.dropna(subset=[COL_AFF]).reset_index(drop=True)
588
+ df_smi = df_smi[
589
+ pd.to_numeric(df_smi[COL_SMI_IPTM], errors="coerce").notna()
590
+ ].reset_index(drop=True) # an empty iptm indicates something is wrong with the SMILES sequence
591
+
592
+ is_uaa = df_smi[COL_SEQ2].astype(str).str.contains("X", case=False, na=False)
593
+ df_smi["smiles_sequence"] = np.where(is_uaa, df_smi[COL_REACT], df_smi[COL_F2S])
594
+ df_smi["smiles_sequence"] = df_smi["smiles_sequence"].astype(str).str.strip()
595
+ df_smi = df_smi[df_smi["smiles_sequence"].notna() & (df_smi["smiles_sequence"] != "")]
596
+ df_smi = df_smi[~df_smi["smiles_sequence"].isin(["nan", "None"])].reset_index(drop=True)
597
+
598
+ log(f"[counts] WT rows={len(df_wt)} | SMILES rows={len(df_smi)} (after per-branch filtering)")
599
+
600
+ # 3) Split separately (different sizes and memberships are expected)
601
+ with section("split wt and smiles separately"):
602
+ df_wt2 = make_distribution_matched_split(df_wt)
603
+ df_smi2 = make_distribution_matched_split(df_smi)
604
+
605
+ # save split tables
606
+ wt_split_csv = OUT_ROOT / "binding_affinity_wt_meta_with_split.csv"
607
+ smi_split_csv = OUT_ROOT / "binding_affinity_smiles_meta_with_split.csv"
608
+ df_wt2.to_csv(wt_split_csv, index=False)
609
+ df_smi2.to_csv(smi_split_csv, index=False)
610
+ log(f"Saved WT split meta: {wt_split_csv}")
611
+ log(f"Saved SMILES split meta: {smi_split_csv}")
612
+
613
+ # lightweight double-check (one-line)
614
+ verify_split_before_embedding(
615
+ df2=df_wt2,
616
+ affinity_col=COL_AFF,
617
+ split_col="split",
618
+ seq_col="wt_sequence",
619
+ iptm_col=COL_WT_IPTM,
620
+ aff_class_col="affinity_class",
621
+ aff_bins=AFFINITY_Q_BINS,
622
+ save_report_prefix=str(OUT_ROOT / "wt_split_doublecheck_report"),
623
+ verbose=False,
624
+ )
625
+ verify_split_before_embedding(
626
+ df2=df_smi2,
627
+ affinity_col=COL_AFF,
628
+ split_col="split",
629
+ seq_col="smiles_sequence",
630
+ iptm_col=COL_SMI_IPTM,
631
+ aff_class_col="affinity_class",
632
+ aff_bins=AFFINITY_Q_BINS,
633
+ save_report_prefix=str(OUT_ROOT / "smiles_split_doublecheck_report"),
634
+ verbose=False,
635
+ )
636
+
637
+ # Prepare split views
638
+ def prep_view(df_in: pd.DataFrame, binder_seq_col: str, iptm_col: str) -> pd.DataFrame:
639
+ out = df_in.copy()
640
+ out["target_sequence"] = out[COL_SEQ1].astype(str).str.strip() # target protein sequence (seq1)
641
+ out["sequence"] = out[binder_seq_col].astype(str).str.strip() # binder
642
+ out["label"] = pd.to_numeric(out[COL_AFF], errors="coerce")
643
+ out[iptm_col] = pd.to_numeric(out[iptm_col], errors="coerce")
644
+ out[COL_AFF] = pd.to_numeric(out[COL_AFF], errors="coerce")
645
+ out = out.dropna(subset=["target_sequence", "sequence", "label"]).reset_index(drop=True)
646
+ return out[["target_sequence", "sequence", "label", "split", iptm_col, COL_AFF, "affinity_class"]]
647
+
648
+ wt_view = prep_view(df_wt2, "wt_sequence", COL_WT_IPTM)
649
+ smi_view = prep_view(df_smi2, "smiles_sequence", COL_SMI_IPTM)
650
+
651
+ # -------------------------
652
+ # Split views
653
+ # -------------------------
654
+ wt_train = wt_view[wt_view["split"] == "train"].reset_index(drop=True)
655
+ wt_val = wt_view[wt_view["split"] == "val"].reset_index(drop=True)
656
+ smi_train = smi_view[smi_view["split"] == "train"].reset_index(drop=True)
657
+ smi_val = smi_view[smi_view["split"] == "val"].reset_index(drop=True)
658
+
659
+
660
+ # =========================
661
+ # TARGET pooled embeddings (ESM) — SEPARATE per branch
662
+ # =========================
663
+ with section("TARGET pooled embeddings (ESM) — WT + SMILES separately"):
664
+ wt_tok = AutoTokenizer.from_pretrained(WT_MODEL_NAME)
665
+ wt_esm = EsmModel.from_pretrained(WT_MODEL_NAME).to(DEVICE).eval()
666
+
667
+ # ---- WT targets ----
668
+ wt_train_tgt_emb = wt_pooled_embeddings(
669
+ wt_train["target_sequence"].astype(str).str.strip().tolist(),
670
+ wt_tok, wt_esm,
671
+ batch_size=WT_BATCH,
672
+ max_length=WT_MAX_LEN,
673
+ ).astype(np.float32)
674
+
675
+ wt_val_tgt_emb = wt_pooled_embeddings(
676
+ wt_val["target_sequence"].astype(str).str.strip().tolist(),
677
+ wt_tok, wt_esm,
678
+ batch_size=WT_BATCH,
679
+ max_length=WT_MAX_LEN,
680
+ ).astype(np.float32)
681
+
682
+ # ---- SMILES targets (independent; may include UAA-only targets) ----
683
+ smi_train_tgt_emb = wt_pooled_embeddings(
684
+ smi_train["target_sequence"].astype(str).str.strip().tolist(),
685
+ wt_tok, wt_esm,
686
+ batch_size=WT_BATCH,
687
+ max_length=WT_MAX_LEN,
688
+ ).astype(np.float32)
689
+
690
+ smi_val_tgt_emb = wt_pooled_embeddings(
691
+ smi_val["target_sequence"].astype(str).str.strip().tolist(),
692
+ wt_tok, wt_esm,
693
+ batch_size=WT_BATCH,
694
+ max_length=WT_MAX_LEN,
695
+ ).astype(np.float32)
696
+
697
+
698
+ # =========================
699
+ # WT pooled binder embeddings (binder = WT peptide)
700
+ # =========================
701
+ with section("WT pooled binder embeddings + save"):
702
+ wt_train_emb = wt_pooled_embeddings(
703
+ wt_train["sequence"].astype(str).str.strip().tolist(),
704
+ wt_tok, wt_esm,
705
+ batch_size=WT_BATCH,
706
+ max_length=WT_MAX_LEN,
707
+ ).astype(np.float32)
708
+
709
+ wt_val_emb = wt_pooled_embeddings(
710
+ wt_val["sequence"].astype(str).str.strip().tolist(),
711
+ wt_tok, wt_esm,
712
+ batch_size=WT_BATCH,
713
+ max_length=WT_MAX_LEN,
714
+ ).astype(np.float32)
715
+
716
+ wt_train_ds = Dataset.from_dict({
717
+ "target_sequence": wt_train["target_sequence"].tolist(),
718
+ "sequence": wt_train["sequence"].tolist(),
719
+ "label": wt_train["label"].astype(float).tolist(),
720
+ "target_embedding": wt_train_tgt_emb,
721
+ "embedding": wt_train_emb,
722
+ COL_WT_IPTM: wt_train[COL_WT_IPTM].astype(float).tolist(),
723
+ COL_AFF: wt_train[COL_AFF].astype(float).tolist(),
724
+ "affinity_class": wt_train["affinity_class"].tolist(),
725
+ })
726
+
727
+ wt_val_ds = Dataset.from_dict({
728
+ "target_sequence": wt_val["target_sequence"].tolist(),
729
+ "sequence": wt_val["sequence"].tolist(),
730
+ "label": wt_val["label"].astype(float).tolist(),
731
+ "target_embedding": wt_val_tgt_emb,
732
+ "embedding": wt_val_emb,
733
+ COL_WT_IPTM: wt_val[COL_WT_IPTM].astype(float).tolist(),
734
+ COL_AFF: wt_val[COL_AFF].astype(float).tolist(),
735
+ "affinity_class": wt_val["affinity_class"].tolist(),
736
+ })
737
+
738
+ wt_pooled_dd = DatasetDict({"train": wt_train_ds, "val": wt_val_ds})
739
+ wt_pooled_out = OUT_ROOT / "pair_wt_wt_pooled"
740
+ wt_pooled_dd.save_to_disk(str(wt_pooled_out))
741
+ log(f"Saved WT pooled -> {wt_pooled_out}")
742
+
743
+
744
+ # =========================
745
+ # SMILES pooled binder embeddings (binder = SMILES via PeptideCLM)
746
+ # =========================
747
+ with section("SMILES pooled binder embeddings + save"):
748
+ smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
749
+ smi_roformer = (
750
+ AutoModelForMaskedLM
751
+ .from_pretrained(SMI_MODEL_NAME)
752
+ .roformer
753
+ .to(DEVICE)
754
+ .eval()
755
+ )
756
+
757
+ smi_train_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
758
+ smi_train["sequence"].astype(str).str.strip().tolist(),
759
+ smi_tok, smi_roformer,
760
+ batch_size=SMI_BATCH,
761
+ max_length=SMI_MAX_LEN,
762
+ )
763
+
764
+ smi_val_pooled, _, _, _ = smiles_generate_embeddings_batched_both(
765
+ smi_val["sequence"].astype(str).str.strip().tolist(),
766
+ smi_tok, smi_roformer,
767
+ batch_size=SMI_BATCH,
768
+ max_length=SMI_MAX_LEN,
769
+ )
770
+
771
+ smi_train_ds = Dataset.from_dict({
772
+ "target_sequence": smi_train["target_sequence"].tolist(),
773
+ "sequence": smi_train["sequence"].tolist(),
774
+ "label": smi_train["label"].astype(float).tolist(),
775
+ "target_embedding": smi_train_tgt_emb,
776
+ "embedding": smi_train_pooled.astype(np.float32),
777
+ COL_SMI_IPTM: smi_train[COL_SMI_IPTM].astype(float).tolist(),
778
+ COL_AFF: smi_train[COL_AFF].astype(float).tolist(),
779
+ "affinity_class": smi_train["affinity_class"].tolist(),
780
+ })
781
+
782
+ smi_val_ds = Dataset.from_dict({
783
+ "target_sequence": smi_val["target_sequence"].tolist(),
784
+ "sequence": smi_val["sequence"].tolist(),
785
+ "label": smi_val["label"].astype(float).tolist(),
786
+ "target_embedding": smi_val_tgt_emb,
787
+ "embedding": smi_val_pooled.astype(np.float32),
788
+ COL_SMI_IPTM: smi_val[COL_SMI_IPTM].astype(float).tolist(),
789
+ COL_AFF: smi_val[COL_AFF].astype(float).tolist(),
790
+ "affinity_class": smi_val["affinity_class"].tolist(),
791
+ })
792
+
793
+ smi_pooled_dd = DatasetDict({"train": smi_train_ds, "val": smi_val_ds})
794
+ smi_pooled_out = OUT_ROOT / "pair_wt_smiles_pooled"
795
+ smi_pooled_dd.save_to_disk(str(smi_pooled_out))
796
+ log(f"Saved SMILES pooled -> {smi_pooled_out}")
797
+
798
+
799
+ # =========================
800
+ # WT unpooled paired (ESM target + ESM binder) + save
801
+ # =========================
802
+ with section("WT unpooled paired embeddings + save"):
803
+ wt_tok_unpooled = wt_tok # reuse tokenizer
804
+ wt_esm_unpooled = wt_esm # reuse model
805
+
806
+ wt_unpooled_out = OUT_ROOT / "pair_wt_wt_unpooled"
807
+ wt_unpooled_dd = DatasetDict({
808
+ "train": build_wt_unpooled_dataset(wt_train, wt_unpooled_out / "train",
809
+ wt_tok_unpooled, wt_esm_unpooled),
810
+ "val": build_wt_unpooled_dataset(wt_val, wt_unpooled_out / "val",
811
+ wt_tok_unpooled, wt_esm_unpooled),
812
+ })
813
+ # (Optional) also save as DatasetDict root if you want a single load_from_disk path:
814
+ wt_unpooled_dd.save_to_disk(str(wt_unpooled_out))
815
+ log(f"Saved WT unpooled -> {wt_unpooled_out}")
816
+
817
+
818
+ # =========================
819
+ # SMILES unpooled paired (ESM target + PeptideCLM binder) + save
820
+ # =========================
821
+ with section("SMILES unpooled paired embeddings + save"):
822
+ # reuse already-loaded smi_tok/smi_roformer from pooled section if still in scope;
823
+ # otherwise re-init here:
824
+ # smi_tok = SMILES_SPE_Tokenizer(TOKENIZER_VOCAB, TOKENIZER_SPLITS)
825
+ # smi_roformer = AutoModelForMaskedLM.from_pretrained(SMI_MODEL_NAME).roformer.to(DEVICE).eval()
826
+
827
+ smi_unpooled_out = OUT_ROOT / "pair_wt_smiles_unpooled"
828
+ smi_unpooled_dd = DatasetDict({
829
+ "train": build_smiles_unpooled_paired_dataset(
830
+ smi_train, smi_unpooled_out / "train",
831
+ wt_tok, wt_esm,
832
+ smi_tok, smi_roformer
833
+ ),
834
+ "val": build_smiles_unpooled_paired_dataset(
835
+ smi_val, smi_unpooled_out / "val",
836
+ wt_tok, wt_esm,
837
+ smi_tok, smi_roformer
838
+ ),
839
+ })
840
+ smi_unpooled_dd.save_to_disk(str(smi_unpooled_out))
841
+ log(f"Saved SMILES unpooled -> {smi_unpooled_out}")
842
+
843
+ log(f"\n[DONE] All datasets saved under: {OUT_ROOT}")
844
+
845
+
846
+ if __name__ == "__main__":
847
+ main()
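As a usage note for the script above: the pooled outputs are standard Hugging Face DatasetDicts, so they can be reloaded and inspected with datasets.load_from_disk. The path below is a placeholder; substitute the OUT_ROOT actually configured for the run.

from datasets import load_from_disk

dd = load_from_disk("path/to/OUT_ROOT/pair_wt_wt_pooled")  # placeholder path
print(dd)                                   # DatasetDict with "train" and "val" splits
row = dd["train"][0]
print(len(row["target_embedding"]))         # ESM2 hidden size (pooled target vector)
print(len(row["embedding"]))                # ESM2 hidden size (pooled WT binder vector)
print(row["label"], row["affinity_class"])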
training_classifiers/binding_training.py ADDED
@@ -0,0 +1,414 @@
1
+ import os, json
2
+ from pathlib import Path
3
+ import numpy as np
4
+ import torch
5
+ import torch.nn as nn
6
+ from torch.utils.data import DataLoader
7
+ import optuna
8
+ from datasets import load_from_disk, DatasetDict
9
+ from scipy.stats import spearmanr
10
+ from lightning.pytorch import seed_everything
11
+ seed_everything(1986)
12
+
13
+ DEVICE = torch.device("cuda:0" if torch.cuda.is_available() else "cpu")
14
+
15
+
16
+ def safe_spearmanr(y_true: np.ndarray, y_pred: np.ndarray) -> float:
17
+ rho = spearmanr(y_true, y_pred).correlation
18
+ if rho is None or np.isnan(rho):
19
+ return 0.0
20
+ return float(rho)
21
+
22
+
23
+ # -----------------------------
24
+ # Affinity class thresholds (final spec)
25
+ # High >= 9 ; Moderate 7-9 ; Low < 7
26
+ # 0=High, 1=Moderate, 2=Low
27
+ # -----------------------------
28
+ def affinity_to_class_tensor(y: torch.Tensor) -> torch.Tensor:
29
+ high = y >= 9.0
30
+ low = y < 7.0
31
+ mid = ~(high | low)
32
+ cls = torch.zeros_like(y, dtype=torch.long)
33
+ cls[mid] = 1
34
+ cls[low] = 2
35
+ return cls
36
+
37
+
38
+ # -----------------------------
39
+ # Load paired DatasetDict
40
+ # -----------------------------
41
+ def load_split_paired(path: str):
42
+ dd = load_from_disk(path)
43
+ if not isinstance(dd, DatasetDict):
44
+ raise ValueError(f"Expected DatasetDict at {path}")
45
+ if "train" not in dd or "val" not in dd:
46
+ raise ValueError(f"DatasetDict missing train/val at {path}")
47
+ return dd["train"], dd["val"]
48
+
49
+
50
+ # -----------------------------
51
+ # Collate: pooled paired
52
+ # -----------------------------
53
+ def collate_pair_pooled(batch):
54
+ Pt = torch.tensor([x["target_embedding"] for x in batch], dtype=torch.float32) # (B,Ht)
55
+ Pb = torch.tensor([x["binder_embedding"] for x in batch], dtype=torch.float32) # (B,Hb)
56
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
57
+ return Pt, Pb, y
58
+
59
+
60
+ # -----------------------------
61
+ # Collate: unpooled paired
62
+ # -----------------------------
63
+ def collate_pair_unpooled(batch):
64
+ B = len(batch)
65
+ Ht = len(batch[0]["target_embedding"][0])
66
+ Hb = len(batch[0]["binder_embedding"][0])
67
+ Lt_max = max(int(x["target_length"]) for x in batch)
68
+ Lb_max = max(int(x["binder_length"]) for x in batch)
69
+
70
+ Pt = torch.zeros(B, Lt_max, Ht, dtype=torch.float32)
71
+ Pb = torch.zeros(B, Lb_max, Hb, dtype=torch.float32)
72
+ Mt = torch.zeros(B, Lt_max, dtype=torch.bool)
73
+ Mb = torch.zeros(B, Lb_max, dtype=torch.bool)
74
+ y = torch.tensor([float(x["label"]) for x in batch], dtype=torch.float32)
75
+
76
+ for i, x in enumerate(batch):
77
+ t = torch.tensor(x["target_embedding"], dtype=torch.float32)
78
+ b = torch.tensor(x["binder_embedding"], dtype=torch.float32)
79
+ lt, lb = t.shape[0], b.shape[0]
80
+ Pt[i, :lt] = t
81
+ Pb[i, :lb] = b
82
+ Mt[i, :lt] = torch.tensor(x["target_attention_mask"][:lt], dtype=torch.bool)
83
+ Mb[i, :lb] = torch.tensor(x["binder_attention_mask"][:lb], dtype=torch.bool)
84
+
85
+ return Pt, Mt, Pb, Mb, y
86
+
87
+
88
+ # -----------------------------
89
+ # Cross-attention models
90
+ # -----------------------------
91
+ class CrossAttnPooled(nn.Module):
92
+ """
93
+ pooled vectors -> treat as single-token sequences for cross attention
94
+ """
95
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
96
+ super().__init__()
97
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
98
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
99
+
100
+ self.layers = nn.ModuleList([])
101
+ for _ in range(n_layers):
102
+ self.layers.append(nn.ModuleDict({
103
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
104
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=False),
105
+ "n1t": nn.LayerNorm(hidden),
106
+ "n2t": nn.LayerNorm(hidden),
107
+ "n1b": nn.LayerNorm(hidden),
108
+ "n2b": nn.LayerNorm(hidden),
109
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
110
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
111
+ }))
112
+
113
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
114
+ self.reg = nn.Linear(hidden, 1)
115
+ self.cls = nn.Linear(hidden, 3)
116
+
117
+ def forward(self, t_vec, b_vec):
118
+ # (B,Ht),(B,Hb)
119
+ t = self.t_proj(t_vec).unsqueeze(0) # (1,B,H)
120
+ b = self.b_proj(b_vec).unsqueeze(0) # (1,B,H)
121
+
122
+ for L in self.layers:
123
+ t_attn, _ = L["attn_tb"](t, b, b)
124
+ t = L["n1t"]((t + t_attn).transpose(0,1)).transpose(0,1)
125
+ t = L["n2t"]((t + L["fft"](t)).transpose(0,1)).transpose(0,1)
126
+
127
+ b_attn, _ = L["attn_bt"](b, t, t)
128
+ b = L["n1b"]((b + b_attn).transpose(0,1)).transpose(0,1)
129
+ b = L["n2b"]((b + L["ffb"](b)).transpose(0,1)).transpose(0,1)
130
+
131
+ t0 = t[0]
132
+ b0 = b[0]
133
+ z = torch.cat([t0, b0], dim=-1)
134
+ h = self.shared(z)
135
+ return self.reg(h).squeeze(-1), self.cls(h)
136
+
137
+
138
+ class CrossAttnUnpooled(nn.Module):
139
+ """
140
+ token sequences with masks; alternating cross attention.
141
+ """
142
+ def __init__(self, Ht, Hb, hidden=512, n_heads=8, n_layers=3, dropout=0.1):
143
+ super().__init__()
144
+ self.t_proj = nn.Sequential(nn.Linear(Ht, hidden), nn.LayerNorm(hidden))
145
+ self.b_proj = nn.Sequential(nn.Linear(Hb, hidden), nn.LayerNorm(hidden))
146
+
147
+ self.layers = nn.ModuleList([])
148
+ for _ in range(n_layers):
149
+ self.layers.append(nn.ModuleDict({
150
+ "attn_tb": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
151
+ "attn_bt": nn.MultiheadAttention(hidden, n_heads, dropout=dropout, batch_first=True),
152
+ "n1t": nn.LayerNorm(hidden),
153
+ "n2t": nn.LayerNorm(hidden),
154
+ "n1b": nn.LayerNorm(hidden),
155
+ "n2b": nn.LayerNorm(hidden),
156
+ "fft": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
157
+ "ffb": nn.Sequential(nn.Linear(hidden, 4*hidden), nn.GELU(), nn.Dropout(dropout), nn.Linear(4*hidden, hidden)),
158
+ }))
159
+
160
+ self.shared = nn.Sequential(nn.Linear(2*hidden, hidden), nn.GELU(), nn.Dropout(dropout))
161
+ self.reg = nn.Linear(hidden, 1)
162
+ self.cls = nn.Linear(hidden, 3)
163
+
164
+ def masked_mean(self, X, M):
165
+ Mf = M.unsqueeze(-1).float()
166
+ denom = Mf.sum(dim=1).clamp(min=1.0)
167
+ return (X * Mf).sum(dim=1) / denom
168
+
169
+ def forward(self, T, Mt, B, Mb):
170
+ # T:(B,Lt,Ht), Mt:(B,Lt) ; B:(B,Lb,Hb), Mb:(B,Lb)
171
+ T = self.t_proj(T)
172
+ Bx = self.b_proj(B)
173
+
174
+ kp_t = ~Mt # key_padding_mask True = pad
175
+ kp_b = ~Mb
176
+
177
+ for L in self.layers:
178
+ # T attends to B
179
+ T_attn, _ = L["attn_tb"](T, Bx, Bx, key_padding_mask=kp_b)
180
+ T = L["n1t"](T + T_attn)
181
+ T = L["n2t"](T + L["fft"](T))
182
+
183
+ # B attends to T
184
+ B_attn, _ = L["attn_bt"](Bx, T, T, key_padding_mask=kp_t)
185
+ Bx = L["n1b"](Bx + B_attn)
186
+ Bx = L["n2b"](Bx + L["ffb"](Bx))
187
+
188
+ t_pool = self.masked_mean(T, Mt)
189
+ b_pool = self.masked_mean(Bx, Mb)
190
+ z = torch.cat([t_pool, b_pool], dim=-1)
191
+ h = self.shared(z)
192
+ return self.reg(h).squeeze(-1), self.cls(h)
193
+
194
+
195
+ # -----------------------------
196
+ # Train/eval
197
+ # -----------------------------
198
+ @torch.no_grad()
199
+ def eval_spearman_pooled(model, loader):
200
+ model.eval()
201
+ ys, ps = [], []
202
+ for t, b, y in loader:
203
+ t = t.to(DEVICE, non_blocking=True)
204
+ b = b.to(DEVICE, non_blocking=True)
205
+ pred, _ = model(t, b)
206
+ ys.append(y.numpy())
207
+ ps.append(pred.detach().cpu().numpy())
208
+ return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
209
+
210
+ @torch.no_grad()
211
+ def eval_spearman_unpooled(model, loader):
212
+ model.eval()
213
+ ys, ps = [], []
214
+ for T, Mt, B, Mb, y in loader:
215
+ T = T.to(DEVICE, non_blocking=True)
216
+ Mt = Mt.to(DEVICE, non_blocking=True)
217
+ B = B.to(DEVICE, non_blocking=True)
218
+ Mb = Mb.to(DEVICE, non_blocking=True)
219
+ pred, _ = model(T, Mt, B, Mb)
220
+ ys.append(y.numpy())
221
+ ps.append(pred.detach().cpu().numpy())
222
+ return safe_spearmanr(np.concatenate(ys), np.concatenate(ps))
223
+
224
+ def train_one_epoch_pooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
225
+ model.train()
226
+ for t, b, y in loader:
227
+ t = t.to(DEVICE, non_blocking=True)
228
+ b = b.to(DEVICE, non_blocking=True)
229
+ y = y.to(DEVICE, non_blocking=True)
230
+ y_cls = affinity_to_class_tensor(y)
231
+
232
+ opt.zero_grad(set_to_none=True)
233
+ pred, logits = model(t, b)
234
+ L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
235
+ L.backward()
236
+ if clip is not None:
237
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
238
+ opt.step()
239
+
240
+ def train_one_epoch_unpooled(model, loader, opt, loss_reg, loss_cls, cls_w=1.0, clip=1.0):
241
+ model.train()
242
+ for T, Mt, B, Mb, y in loader:
243
+ T = T.to(DEVICE, non_blocking=True)
244
+ Mt = Mt.to(DEVICE, non_blocking=True)
245
+ B = B.to(DEVICE, non_blocking=True)
246
+ Mb = Mb.to(DEVICE, non_blocking=True)
247
+ y = y.to(DEVICE, non_blocking=True)
248
+ y_cls = affinity_to_class_tensor(y)
249
+
250
+ opt.zero_grad(set_to_none=True)
251
+ pred, logits = model(T, Mt, B, Mb)
252
+ L = loss_reg(pred, y) + cls_w * loss_cls(logits, y_cls)
253
+ L.backward()
254
+ if clip is not None:
255
+ torch.nn.utils.clip_grad_norm_(model.parameters(), clip)
256
+ opt.step()
257
+
258
+
259
+ # -----------------------------
260
+ # Optuna objective
261
+ # -----------------------------
262
+ def objective_crossattn(trial: optuna.Trial, mode: str, train_ds, val_ds) -> float:
263
+ lr = trial.suggest_float("lr", 1e-5, 3e-3, log=True)
264
+ wd = trial.suggest_float("weight_decay", 1e-10, 1e-2, log=True)
265
+ dropout = trial.suggest_float("dropout", 0.0, 0.4)
266
+ hidden = trial.suggest_categorical("hidden_dim", [256, 384, 512, 768])
267
+ n_heads = trial.suggest_categorical("n_heads", [4, 8])
268
+ n_layers = trial.suggest_int("n_layers", 1, 4)
269
+ cls_w = trial.suggest_float("cls_weight", 0.1, 2.0, log=True)
270
+ batch = trial.suggest_categorical("batch_size", [16, 32, 64, 128])
271
+
272
+ # infer dims from first row
273
+ if mode == "pooled":
274
+ Ht = len(train_ds[0]["target_embedding"])
275
+ Hb = len(train_ds[0]["binder_embedding"])
276
+ collate = collate_pair_pooled
277
+ model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
278
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
279
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
280
+ eval_fn = eval_spearman_pooled
281
+ train_fn = train_one_epoch_pooled
282
+
283
+ else:
284
+ Ht = len(train_ds[0]["target_embedding"][0])
285
+ Hb = len(train_ds[0]["binder_embedding"][0])
286
+ collate = collate_pair_unpooled
287
+ model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
288
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
289
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
290
+ eval_fn = eval_spearman_unpooled
291
+ train_fn = train_one_epoch_unpooled
292
+
293
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
294
+ loss_reg = nn.MSELoss()
295
+ loss_cls = nn.CrossEntropyLoss()
296
+
297
+ best = -1e9
298
+ bad = 0
299
+ patience = 10
300
+
301
+ for ep in range(1, 61):
302
+ train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
303
+ rho = eval_fn(model, val_loader)
304
+
305
+ trial.report(rho, ep)
306
+ if trial.should_prune():
307
+ raise optuna.TrialPruned()
308
+
309
+ if rho > best + 1e-6:
310
+ best = rho
311
+ bad = 0
312
+ else:
313
+ bad += 1
314
+ if bad >= patience:
315
+ break
316
+
317
+ return float(best)
318
+
319
+
320
+ # -----------------------------
321
+ # Run: optuna + refit best
322
+ # -----------------------------
323
+ def run(dataset_path: str, out_dir: str, mode: str, n_trials: int = 50):
324
+ out_dir = Path(out_dir)
325
+ out_dir.mkdir(parents=True, exist_ok=True)
326
+
327
+ train_ds, val_ds = load_split_paired(dataset_path)
328
+ print(f"[Data] Train={len(train_ds)} Val={len(val_ds)} | mode={mode}")
329
+
330
+ study = optuna.create_study(direction="maximize", pruner=optuna.pruners.MedianPruner())
331
+ study.optimize(lambda t: objective_crossattn(t, mode, train_ds, val_ds), n_trials=n_trials)
332
+
333
+ study.trials_dataframe().to_csv(out_dir / "optuna_trials.csv", index=False)
334
+ best = study.best_trial
335
+ best_params = dict(best.params)
336
+
337
+ # refit longer
338
+ lr = float(best_params["lr"])
339
+ wd = float(best_params["weight_decay"])
340
+ dropout = float(best_params["dropout"])
341
+ hidden = int(best_params["hidden_dim"])
342
+ n_heads = int(best_params["n_heads"])
343
+ n_layers = int(best_params["n_layers"])
344
+ cls_w = float(best_params["cls_weight"])
345
+ batch = int(best_params["batch_size"])
346
+
347
+ loss_reg = nn.MSELoss()
348
+ loss_cls = nn.CrossEntropyLoss()
349
+
350
+ if mode == "pooled":
351
+ Ht = len(train_ds[0]["target_embedding"])
352
+ Hb = len(train_ds[0]["binder_embedding"])
353
+ model = CrossAttnPooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
354
+ collate = collate_pair_pooled
355
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
356
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
357
+ eval_fn = eval_spearman_pooled
358
+ train_fn = train_one_epoch_pooled
359
+ else:
360
+ Ht = len(train_ds[0]["target_embedding"][0])
361
+ Hb = len(train_ds[0]["binder_embedding"][0])
362
+ model = CrossAttnUnpooled(Ht, Hb, hidden=hidden, n_heads=n_heads, n_layers=n_layers, dropout=dropout).to(DEVICE)
363
+ collate = collate_pair_unpooled
364
+ train_loader = DataLoader(train_ds, batch_size=batch, shuffle=True, num_workers=4, pin_memory=True, collate_fn=collate)
365
+ val_loader = DataLoader(val_ds, batch_size=batch, shuffle=False, num_workers=4, pin_memory=True, collate_fn=collate)
366
+ eval_fn = eval_spearman_unpooled
367
+ train_fn = train_one_epoch_unpooled
368
+
369
+ opt = torch.optim.AdamW(model.parameters(), lr=lr, weight_decay=wd)
370
+
371
+ best_rho = -1e9
372
+ bad = 0
373
+ patience = 20
374
+ best_state = None
375
+
376
+ for ep in range(1, 201):
377
+ train_fn(model, train_loader, opt, loss_reg, loss_cls, cls_w=cls_w)
378
+ rho = eval_fn(model, val_loader)
379
+
380
+ if rho > best_rho + 1e-6:
381
+ best_rho = rho
382
+ bad = 0
383
+ best_state = {k: v.detach().cpu().clone() for k, v in model.state_dict().items()}
384
+ else:
385
+ bad += 1
386
+ if bad >= patience:
387
+ break
388
+
389
+ if best_state is not None:
390
+ model.load_state_dict(best_state)
391
+
392
+ # save
393
+ torch.save({"mode": mode, "best_params": best_params, "state_dict": model.state_dict()}, out_dir / "best_model.pt")
394
+ with open(out_dir / "best_params.json", "w") as f:
395
+ json.dump(best_params, f, indent=2)
396
+
397
+ print(f"[DONE] {out_dir} | best_optuna_rho={study.best_value:.4f} | refit_best_rho={best_rho:.4f}")
398
+
399
+
400
+ if __name__ == "__main__":
401
+ import argparse
402
+ ap = argparse.ArgumentParser()
403
+ ap.add_argument("--dataset_path", type=str, required=True, help="Paired DatasetDict path (pair_*)")
404
+ ap.add_argument("--mode", type=str, choices=["pooled", "unpooled"], required=True)
405
+ ap.add_argument("--out_dir", type=str, required=True)
406
+ ap.add_argument("--n_trials", type=int, default=50)
407
+ args = ap.parse_args()
408
+
409
+ run(
410
+ dataset_path=args.dataset_path,
411
+ out_dir=args.out_dir,
412
+ mode=args.mode,
413
+ n_trials=args.n_trials,
414
+ )
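For reference, a minimal way to launch a tuning run from Python, mirroring the CLI defined above; both paths are placeholders and should point at a pair_* DatasetDict produced by binding_affinity_split.py and at a fresh output directory.

from binding_training import run

run(
    dataset_path="path/to/OUT_ROOT/pair_wt_smiles_pooled",  # placeholder
    out_dir="path/to/outputs/wt_smiles_pooled",             # placeholder
    mode="pooled",      # "pooled" or "unpooled", matching the dataset layout
    n_trials=50,
)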
training_classifiers/binding_wt.bash ADDED
@@ -0,0 +1,31 @@
1
+ #!/bin/bash
2
+ #SBATCH --job-name=b-data
3
+ #SBATCH --partition=dgx-b200
4
+ #SBATCH --gpus=1
5
+ #SBATCH --cpus-per-task=10
6
+ #SBATCH --mem=100G
7
+ #SBATCH --time=48:00:00
8
+ #SBATCH --output=%x_%j.out
9
+
10
+ HOME_LOC=/vast/projects/pranam/lab/yz927
11
+ SCRIPT_LOC=$HOME_LOC/projects/Classifier_Weight/training_classifiers
12
+ DATA_LOC=$HOME_LOC/projects/Classifier_Weight/training_data_cleaned
13
+ OBJECTIVE='binding_affinity'
14
+ WT='smiles' #wt/smiles
15
+ STATUS='pooled' #pooled/unpooled
16
+ DATA_FILE="pair_wt_${WT}_${STATUS}"
17
+ LOG_LOC=$SCRIPT_LOC
18
+ DATE=$(date +%m_%d)
19
+ SPECIAL_PREFIX="binding_affinity_data_generation"
20
+
21
+ # Create log directory if it doesn't exist
22
+ mkdir -p $LOG_LOC
23
+
24
+ cd $SCRIPT_LOC
25
+ source /vast/projects/pranam/lab/shared/miniconda3/etc/profile.d/conda.sh
26
+ conda activate /vast/projects/pranam/lab/shared/miniconda3/envs/metal
27
+
28
+ python -u binding_affinity_split.py > "${LOG_LOC}/${DATE}_${SPECIAL_PREFIX}.log" 2>&1
29
+
30
+ echo "Script completed at $(date)"
31
+ conda deactivate
training_classifiers/hemolysis/cnn_smiles/best_model.pt ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:dd8f379ef2a10dacff4236ca37aa64832a3ce8bc9608ca1297b1b7662780ee6f
3
+ size 14170677
training_classifiers/hemolysis/cnn_smiles/best_model_benchmark.json ADDED
@@ -0,0 +1,39 @@
1
+ {
2
+ "n_samples": 800,
3
+ "wall_time_s": 6.477548930000921,
4
+ "throughput_samples_per_s": 123.50350551499211,
5
+ "gpu_total_kernel_ms": 39.01705604791641,
6
+ "gpu_ms_per_sample": 0.04877132005989551,
7
+ "gpu_avg_ms_per_batch": 0.7803411209583282,
8
+ "gpu_peak_mem_MB": 219.23974609375,
9
+ "telemetry_pre": {
10
+ "cpu_freq_current_MHz": 1357.0153303571428,
11
+ "cpu_freq_max_MHz": 4000.0,
12
+ "cpu_util_pct": 5.7,
13
+ "cpu_count_logical": 224,
14
+ "cpu_count_physical": 112,
15
+ "gpu_util_pct": 0,
16
+ "gpu_mem_util_pct": 0,
17
+ "gpu_mem_used_MB": 1829.0625,
18
+ "gpu_mem_total_MB": 183359.0,
19
+ "gpu_sm_clock_MHz": 1965,
20
+ "gpu_mem_clock_MHz": 3996,
21
+ "gpu_power_W": 194.631,
22
+ "gpu_temp_C": 31
23
+ },
24
+ "telemetry_post": {
25
+ "cpu_freq_current_MHz": 1529.574044642856,
26
+ "cpu_freq_max_MHz": 4000.0,
27
+ "cpu_util_pct": 6.0,
28
+ "cpu_count_logical": 224,
29
+ "cpu_count_physical": 112,
30
+ "gpu_util_pct": 0,
31
+ "gpu_mem_util_pct": 0,
32
+ "gpu_mem_used_MB": 1923.0625,
33
+ "gpu_mem_total_MB": 183359.0,
34
+ "gpu_sm_clock_MHz": 1965,
35
+ "gpu_mem_clock_MHz": 3996,
36
+ "gpu_power_W": 197.518,
37
+ "gpu_temp_C": 31
38
+ }
39
+ }
training_classifiers/hemolysis/cnn_smiles/optimization_summary.txt ADDED
@@ -0,0 +1,19 @@
1
+ ========================================================================
2
+ MODEL: cnn
3
+ Best Optuna F1 (objective): 0.5290
4
+ Best Optuna AUC (val, recorded): 0.7477
5
+ Best Optuna threshold (val): 0.4518
6
+ Refit best AUC (val): 0.7851
7
+ Refit best F1@thr (val): 0.5366 at thr=0.5298
8
+ Best params:
9
+ {
10
+ "lr": 0.0002237456677696451,
11
+ "weight_decay": 0.0005722918417016266,
12
+ "dropout": 0.2697397384794115,
13
+ "batch_size": 16,
14
+ "channels": 512,
15
+ "kernel": 3,
16
+ "layers": 4
17
+ }
18
+ Saved model: /vast/projects/pranam/lab/yz927/projects/Classifier_Weight/training_classifiers/hemolysis/cnn_smiles/best_model.pt
19
+ ========================================================================
training_classifiers/hemolysis/cnn_smiles/pr_curve.png ADDED
training_classifiers/hemolysis/cnn_smiles/roc_curve.png ADDED
training_classifiers/hemolysis/cnn_smiles/study_trials.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:341f45fdbd7565793d9ff64a3d1be6ebd6d56d6e6957db19d119a946c658d296
3
+ size 48177
training_classifiers/hemolysis/cnn_smiles/train_predictions.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:44ee47243c96b0d372586ee8c40582af2aa086cfb626c5826b05778f1f08b919
3
+ size 1943431
training_classifiers/hemolysis/cnn_smiles/val_predictions.csv ADDED
@@ -0,0 +1,3 @@
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:df88e6577184fbf56018945e0540c5d5679981ff74af08422b03cd6aab43d5e3
3
+ size 472104